"""Chunk scraped web pages and store their embeddings in a ChromaDB collection."""

import copy
import logging
import os
import re

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class EmbeddingModel:
    """Embed web pages with a sentence-transformers model into a ChromaDB collection.

    Attributes set by ``__init__``:
        tokenizer: Tokenizer loaded with ``AutoTokenizer.from_pretrained(model_name)``,
            used only to count tokens.
        model: ``SentenceTransformer(model_name)`` used to compute the embeddings.
        chroma_client: ``chromadb.PersistentClient`` opened on ``chromadb_path``.
        collection: The collection (created if missing) that receives the documents.
    """

    def __init__(self, model_name, chromadb_path, collection_name):
        """
        Loads the tokenizer and the embedding model, and opens the ChromaDB collection.

        Args:
            model_name (str): Name of the model, passed to both `AutoTokenizer` and `SentenceTransformer`.
            chromadb_path (str): Path given to `chromadb.PersistentClient`.
            collection_name (str): Name of the collection to get or create.
        """
        # The third-party packages are imported here rather than at module level so
        # that importing this module (and using the pure chunking logic) does not
        # require loading the whole ML stack.
        from transformers import AutoTokenizer
        from sentence_transformers import SentenceTransformer
        import chromadb

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def token_length(self, text):
        """
        Calculates the token length of a given text.

        Args:
            text (str): The text to be tokenized.

        Returns:
            int: The number of tokens in the text, special tokens excluded.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def passage_str(self, paragraphs, title):
        """
        Constructs a passage string from given paragraphs and a title.

        Args:
            paragraphs (list of str): A list of paragraphs.
            title (str): The title of the passage.

        Returns:
            str: "passage: <title>" on the first line, followed by the paragraphs,
            one per line.
        """
        return f"passage: {title}\n" + '\n'.join(paragraphs)

    def _split_paragraph(self, paragraph, title, max_chunk_size):
        """
        Splits one over-long paragraph into smaller parts.

        Args:
            paragraph (str): The paragraph to split.
            title (str): The page title (needed to measure the resulting documents).
            max_chunk_size (int): The maximum token length for a chunk.

        Returns:
            list of str: The non-empty parts, each strictly shorter than `paragraph`,
            or an empty list when the paragraph cannot be split any further.
        """
        sentences = re.split(r'(?<=[.!?]) +', paragraph)

        if len(sentences) > 1:
            # Take the first sentence boundary at which either both parts fit or the
            # first part has become too long. The boundary never goes past the last
            # sentence: otherwise the first part would be the whole paragraph again
            # and the caller would loop forever.
            for i in range(1, len(sentences)):
                part1 = ' '.join(sentences[:i])
                part2 = ' '.join(sentences[i:])
                if (self.token_length(self.passage_str([part1], title)) > max_chunk_size
                        or self.token_length(self.passage_str([part2], title)) <= max_chunk_size):
                    break
            parts = [part1, part2]
        else:
            words = paragraph.split()
            if len(words) > 1:
                # Only one long sentence: cut it in the middle, between two words
                half = len(words) // 2
                parts = [' '.join(words[:half]), ' '.join(words[half:])]
            else:
                # A single unbreakable "word": cut it in the middle, between two characters
                text = paragraph.strip()
                if len(text) < 2:
                    return []
                half = len(text) // 2
                parts = [text[:half], text[half:]]

        # An empty part (e.g. after a trailing space) would only produce a document
        # made of the title alone
        return [part for part in parts if part.strip()]

    def _chunk_page(self, title, contents, max_chunk_size):
        """
        Divides the paragraphs of a page into documents of at most `max_chunk_size` tokens.

        Args:
            title (str): The title of the webpage.
            contents (list of str): The contents of the webpage, split into paragraphs.
            max_chunk_size (int): The maximum token length for a chunk.

        Returns:
            list of str: The passage strings, in the natural reading order.

        Raises:
            ValueError: If the content must be split but the title prefix alone
                already uses `max_chunk_size` tokens or more.
        """
        header_length = self.token_length(self.passage_str([], title))
        documents = []
        contents_to_embed = [contents]

        while contents_to_embed:
            last_item = contents_to_embed.pop()
            # (1) For the `multilingual-e5-large` embedding model,
            #     the string of a document must be prepended with "passage:"
            # (2) Since the text of a webpage may have to be cut into many documents,
            #     we always add the title of the webpage at the top of a document
            last_item_str = self.passage_str(last_item, title)

            if self.token_length(last_item_str) <= max_chunk_size:
                documents.append(last_item_str)
                continue

            if header_length >= max_chunk_size:
                # No amount of splitting can make a chunk fit
                raise ValueError(
                    f"the title prefix uses {header_length} tokens, "
                    f"which leaves no room for content in chunks of {max_chunk_size} tokens"
                )

            if len(last_item) > 1:
                # Many paragraphs: divide the set of paragraphs in half
                h = len(last_item) // 2
                contents_to_embed.append(last_item[:h])
                contents_to_embed.append(last_item[h:])
                continue

            # Only one paragraph, whose length is larger than `max_chunk_size`
            parts = self._split_paragraph(last_item[0], title, max_chunk_size)
            if not parts:
                logging.warning("Keeping a chunk longer than %d tokens: it cannot be split", max_chunk_size)
                documents.append(last_item_str)
                continue
            contents_to_embed.extend([part] for part in parts)

        # The stack yields the chunks last-to-first; we want the documents into which
        # a webpage has been divided to be in the natural reading order
        documents.reverse()
        return documents

    def embed_page(self, filename, url, title, contents, tags, max_chunk_size=512):
        """
        Embeds the text of a webpage into the ChromaDB collection.

        Args:
            filename (str): The name of the file being processed (prefix of the document ids).
            url (str): The URL of the webpage.
            title (str): The title of the webpage.
            contents (list of str): The contents of the webpage, split into paragraphs.
            tags (list of str): Tags derived from the URL.
            max_chunk_size (int): The maximum token length for a chunk. Defaults to 512.

        Returns:
            None

        Raises:
            ValueError: If the content must be split but the title prefix alone
                already uses `max_chunk_size` tokens or more.

        This function divides the webpage content into chunks that fit within the max_chunk_size limit,
        embeds each chunk with `self.model`, and adds the embeddings to `self.collection`
        with ids "<filename>-1", "<filename>-2", ...
        """
        documents = self._chunk_page(title, contents, max_chunk_size)
        embeddings = self.model.encode(documents, normalize_embeddings=True)
        embeddings = embeddings.tolist()

        # We consider the subpart of an URL as tags describing the webpage
        # For example,
        # "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
        # is associated to the tags:
        # tags[0] == 'rhone-alpes'
        # tags[1] == 'professionnels'
        # tags[2] == 'financer-projets-optimiser-tresorerie'
        if len(tags) < 2:
            category = ''
        elif tags[0] == 'rhone-alpes':
            category = tags[1]
        else:
            category = tags[0]
        metadata = {'category': category, 'url': url}
        # All the documents corresponding to a same webpage have the same metadata, i.e. URL and category
        metadatas = [copy.deepcopy(metadata) for _ in documents]

        ids = [f"{filename}-{i}" for i in range(1, len(documents) + 1)]

        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)

    def embed_folder(self, folder_path):
        """
        Embeds all the .txt files within a specified folder into the ChromaDB collection.

        Args:
            folder_path (str): Path to the folder containing .txt files.

        Returns:
            None

        Each file is expected to contain blocks separated by a blank line: the URL,
        the title, then the paragraphs. Files with fewer than three blocks, with a '?'
        in the URL, or with an empty title are skipped; the others are passed to `embed_page`.
        """
        prefix = 'https://www.caisse-epargne.fr/'

        for filename in os.listdir(folder_path):
            if not filename.endswith('.txt'):
                continue
            file_path = os.path.join(folder_path, filename)
            # The encoding is explicit so that the accented text is read the same way
            # whatever the platform default is
            with open(file_path, 'r', encoding='utf-8') as file:
                file_contents = file.read()
            contents_lst = [
                block.replace('\n', ' ').replace('\xa0', ' ')
                for block in file_contents.split('\n\n')
            ]
            if len(contents_lst) < 3:  # contents_lst[0] is the URL, contents_lst[1] is the title, the rest is the content
                continue
            url = contents_lst[0]
            if '?' in url:  # URLs with a '?' correspond to calls to services and have no useful content
                continue
            title = contents_lst[1]
            if not title:  # when the title is absent (or empty), the page has no interest
                continue
            logging.info("%s : Start", filename)
            suffix = url.replace(prefix, '')
            tags = [tag for tag in suffix.split('/') if tag]  # remove empty parts
            self.embed_page(filename, url, title, contents_lst[2:], tags)
            logging.info("%s : Done", filename)
Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n", + "2024-01-03 11:13:56,891 - INFO - Use pytorch device: cpu\n", + "2024-01-03 11:13:56,894 - INFO - Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.\n", + "2024-01-03 11:13:56,990 - INFO - 4a06529f5f.txt : Start\n", + "Batches: 0%| | 0/1 [00:00 512). Running this sequence through the model will result in indexing errors\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.93it/s]\n", + "2024-01-03 11:13:58,189 - INFO - 4aac6081e0.txt : Done\n", + "2024-01-03 11:13:58,189 - INFO - 4a5736d002.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 5.89it/s]\n", + "2024-01-03 11:13:58,365 - INFO - 4a5736d002.txt : Done\n", + "2024-01-03 11:13:58,366 - INFO - 3d159cbe89.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.63it/s]\n", + "2024-01-03 11:13:58,988 - INFO - 3d159cbe89.txt : Done\n", + "2024-01-03 11:13:58,989 - INFO - 3f3e46760c.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 6.07it/s]\n", + "2024-01-03 11:13:59,159 - INFO - 3f3e46760c.txt : Done\n", + "2024-01-03 11:13:59,160 - INFO - 3ced86d1db.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 2.12it/s]\n", + "2024-01-03 11:13:59,640 - INFO - 3ced86d1db.txt : Done\n", + "2024-01-03 11:13:59,641 - INFO - 3bbe30b18a.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:01<00:00, 1.46s/it]\n", + "2024-01-03 11:14:01,116 - INFO - 3bbe30b18a.txt : Done\n", + "2024-01-03 11:14:01,116 - INFO - 3dbfdeb28e.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:01<00:00, 1.17s/it]\n", + "2024-01-03 11:14:02,299 - INFO - 3dbfdeb28e.txt : Done\n", + "2024-01-03 11:14:02,299 - INFO - 4adf02d48f.txt : Start\n", + "Batches: 
100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.71it/s]\n", + "2024-01-03 11:14:02,895 - INFO - 4adf02d48f.txt : Done\n", + "2024-01-03 11:14:02,896 - INFO - 3c25273538.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:02<00:00, 2.02s/it]\n", + "2024-01-03 11:14:04,940 - INFO - 3c25273538.txt : Done\n", + "2024-01-03 11:14:04,940 - INFO - 4aeb967bdb.txt : Start\n", + "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 2.00it/s]\n", + "2024-01-03 11:14:05,449 - INFO - 4aeb967bdb.txt : Done\n" + ] + } + ], + "source": [ + "model_name = 'intfloat/multilingual-e5-large'\n", + "chromadb_path = './chromadbtest'\n", + "folder_path = './docs/test'\n", + "collection_name = 'cera'\n", + "\n", + "embedding_model = EmbeddingModel(model_name, chromadb_path, collection_name)\n", + "embedding_model.embed_folder(folder_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2acd9c49-5676-4e72-9eff-f6fb8ffa94fe", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "RAG_ENV", + "language": "python", + "name": "rag_env" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}