packaging of the embedding process

This commit is contained in:
Pierre-Edouard Portier 2024-01-03 16:24:17 +01:00
parent 34dfa18f6d
commit 51df8e6269
2 changed files with 305 additions and 0 deletions

185
embedding.py Normal file
View File

@ -0,0 +1,185 @@
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import os
import re
import copy
import chromadb
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class EmbeddingModel:
def __init__(self, model_name, chromadb_path, collection_name):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = SentenceTransformer(model_name)
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
self.collection = self.chroma_client.get_or_create_collection(name=collection_name)
def token_length(self, text):
"""
Calculates the token length of a given text
Args:
text (str): The text to be tokenized.
Returns:
int: The number of tokens in the text.
This function takes a string, tokenizes the string, and returns the number of tokens.
"""
return len(self.tokenizer.encode(text, add_special_tokens=False))
def passage_str(self, paragraphs, title):
"""
Constructs a passage string from given paragraphs and a title.
Args:
paragraphs (list of str): A list of paragraphs.
title (str): The title of the passage.
Returns:
str: A passage string that combines the title and paragraphs.
This function takes a list of paragraphs and a title, and constructs a single string
with the title followed by the paragraphs, formatted for embedding.
"""
return f"passage: {title}\n" + '\n'.join(paragraphs)
def embed_page(self, filename, url, title, contents, tags, max_chunk_size=512):
"""
Embeds the text of a webpage into a ChromaDB collection.
Args:
filename (str): The name of the file being processed.
url (str): The URL of the webpage.
title (str): The title of the webpage.
contents (list of str): The contents of the webpage, split into paragraphs.
tags (list of str): Tags derived from the URL.
max_chunk_size (int): The maximum token length for a chunk. Defaults to 512.
Returns:
None
This function divides the webpage content into chunks that fit within the max_chunk_size limit,
embeds each chunk using the provided model, and stores the embeddings in the specified ChromaDB collection.
"""
documents = []
contents_to_embed = [contents]
while contents_to_embed:
last_item = contents_to_embed.pop()
# (1) For the `multilingual-e5-large` embedding model,
# the string of a document must be prepended with "passage:"
# (2) Since the text of a webpage may have to be cut into many documents,
# we always add the title of the webpage at the top of a document
last_item_str = self.passage_str(last_item, title)
last_item_token_length = self.token_length(last_item_str)
if last_item_token_length > max_chunk_size:
# If the text of the webpage, present in file `filename`,
# contains more than `max_chunk_size` tokens, it must be divided
# into multiple documents
if len(last_item) > 1:
# If there are many paragraphs in `last_item`, i.e. the current
# part of the webpage for which an embedding will be made,
# the length of `last_item` can be reduced by dividing its set of
# paragraphs in half
h = len(last_item) // 2
last_item_h1 = last_item[:h]
last_item_h2 = last_item[h:]
contents_to_embed.append(last_item_h1)
contents_to_embed.append(last_item_h2)
else:
# If `last_item` is made of only one long paragraph whose length is
# larger than `chunk_size`, this paragraph will be divided into two parts.
sentences = re.split(r'(?<=[.!?]) +', last_item[0])
if len(sentences) > 1:
# If there are multiple sentences, try to split into two parts
i = 1
while True:
part1 = ' '.join(sentences[:i])
part2 = ' '.join(sentences[i:])
token_length_part_1 = self.token_length(self.passage_str([part1], title))
token_length_part_2 = self.token_length(self.passage_str([part2], title))
if (token_length_part_1 <= max_chunk_size and
token_length_part_2 <= max_chunk_size) or \
token_length_part_1 > max_chunk_size:
break
i += 1
else:
# If there's only one long sentence or no suitable split found, split by words
words = last_item[0].split()
h = len(words) // 2
part1 = ' '.join(words[:h])
part2 = ' '.join(words[h:])
contents_to_embed.append([part1])
contents_to_embed.append([part2])
else:
documents.append(last_item_str)
# We want the documents into which a webpage has been divided
# to be in the natural reading order
documents.reverse()
embeddings = self.model.encode(documents, normalize_embeddings=True)
embeddings = embeddings.tolist()
# We consider the subpart of an URL as tags describing the webpage
# For example,
# "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
# is associated to the tags:
# tags[0] == 'rhone-alpes'
# tags[1] == 'professionnels'
# tags[2] == 'financer-projets-optimiser-tresorerie'
if len(tags) < 2:
category = ''
else:
if tags[0] == 'rhone-alpes':
category = tags[1]
else: category = tags[0]
metadata = {'category': category, 'url': url}
# All the documents corresponding to a same webpage have the same metadata, i.e. URL and category
metadatas = [copy.deepcopy(metadata) for _ in range(len(documents))]
ids = [filename + '-' + str(i+1) for i in range(len(documents))]
self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
def embed_folder(self, folder_path):
"""
Embeds all the .txt files within a specified folder into a ChromaDB collection using a specified embedding model.
Args:
folder_path (str): Path to the folder containing .txt files.
Returns:
None
This function processes each .txt file in the given folder, extracts the content, and uses `embed_page`
to embed the content into the specified ChromaDB collection.
"""
for filename in os.listdir(folder_path):
if filename.endswith('.txt'):
file_path = os.path.join(folder_path, filename)
with open(file_path, 'r') as file:
file_contents = file.read()
contents_lst = [str.replace('\n',' ').replace('\xa0', ' ') for str in file_contents.split('\n\n')]
if len(contents_lst) < 3: # contents_lst[0] is the URL, contents_lst[1] is the title, the rest is the content
continue
url = contents_lst[0]
if '?' in url: # URLs with a '?' corresponds to call to services and have no useful content
continue
title = contents_lst[1]
if not title: # when the title is absent (or empty), the page has no interest
continue
logging.info(f"{filename} : Start")
prefix = 'https://www.caisse-epargne.fr/'
suffix = url.replace(prefix, '')
tags = suffix.split('/')
tags = [tag for tag in tags if tag] # remove empty parts
self.embed_page(filename, url, title, contents_lst[2:], tags)
logging.info(f"{filename} : Done")

120
rag_fr_embedding.ipynb Normal file
View File

@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "98de82f6-2dc9-4d27-a5d8-d07ae04b496c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/peportier/miniforge3/envs/RAG_ENV/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/Users/peportier/miniforge3/envs/RAG_ENV/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _torch_pytree._register_pytree_node(\n"
]
}
],
"source": [
"from embedding import EmbeddingModel"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "37408a48-ce90-4176-bc9f-b71ebc22a178",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-01-03 11:13:53,279 - INFO - Load pretrained SentenceTransformer: intfloat/multilingual-e5-large\n",
"/Users/peportier/miniforge3/envs/RAG_ENV/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
" _torch_pytree._register_pytree_node(\n",
"2024-01-03 11:13:56,891 - INFO - Use pytorch device: cpu\n",
"2024-01-03 11:13:56,894 - INFO - Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.\n",
"2024-01-03 11:13:56,990 - INFO - 4a06529f5f.txt : Start\n",
"Batches: 0%| | 0/1 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.64it/s]\n",
"2024-01-03 11:13:57,660 - INFO - 4a06529f5f.txt : Done\n",
"2024-01-03 11:13:57,660 - INFO - 4aac6081e0.txt : Start\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (595 > 512). Running this sequence through the model will result in indexing errors\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.93it/s]\n",
"2024-01-03 11:13:58,189 - INFO - 4aac6081e0.txt : Done\n",
"2024-01-03 11:13:58,189 - INFO - 4a5736d002.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 5.89it/s]\n",
"2024-01-03 11:13:58,365 - INFO - 4a5736d002.txt : Done\n",
"2024-01-03 11:13:58,366 - INFO - 3d159cbe89.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.63it/s]\n",
"2024-01-03 11:13:58,988 - INFO - 3d159cbe89.txt : Done\n",
"2024-01-03 11:13:58,989 - INFO - 3f3e46760c.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 6.07it/s]\n",
"2024-01-03 11:13:59,159 - INFO - 3f3e46760c.txt : Done\n",
"2024-01-03 11:13:59,160 - INFO - 3ced86d1db.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 2.12it/s]\n",
"2024-01-03 11:13:59,640 - INFO - 3ced86d1db.txt : Done\n",
"2024-01-03 11:13:59,641 - INFO - 3bbe30b18a.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:01<00:00, 1.46s/it]\n",
"2024-01-03 11:14:01,116 - INFO - 3bbe30b18a.txt : Done\n",
"2024-01-03 11:14:01,116 - INFO - 3dbfdeb28e.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:01<00:00, 1.17s/it]\n",
"2024-01-03 11:14:02,299 - INFO - 3dbfdeb28e.txt : Done\n",
"2024-01-03 11:14:02,299 - INFO - 4adf02d48f.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.71it/s]\n",
"2024-01-03 11:14:02,895 - INFO - 4adf02d48f.txt : Done\n",
"2024-01-03 11:14:02,896 - INFO - 3c25273538.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:02<00:00, 2.02s/it]\n",
"2024-01-03 11:14:04,940 - INFO - 3c25273538.txt : Done\n",
"2024-01-03 11:14:04,940 - INFO - 4aeb967bdb.txt : Start\n",
"Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 2.00it/s]\n",
"2024-01-03 11:14:05,449 - INFO - 4aeb967bdb.txt : Done\n"
]
}
],
"source": [
"model_name = 'intfloat/multilingual-e5-large'\n",
"chromadb_path = './chromadbtest'\n",
"folder_path = './docs/test'\n",
"collection_name = 'cera'\n",
"\n",
"embedding_model = EmbeddingModel(model_name, chromadb_path, collection_name)\n",
"embedding_model.embed_folder(folder_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2acd9c49-5676-4e72-9eff-f6fb8ffa94fe",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "RAG_ENV",
"language": "python",
"name": "rag_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}