packaging of the embedding process
parent 34dfa18f6d
commit 51df8e6269

embedding.py (new file, 185 lines)
@@ -0,0 +1,185 @@
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import os
import re
import copy
import chromadb
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class EmbeddingModel:

    def __init__(self, model_name, chromadb_path, collection_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def token_length(self, text):
        """
        Calculates the token length of a given text.

        Args:
            text (str): The text to be tokenized.

        Returns:
            int: The number of tokens in the text.

        This function takes a string, tokenizes the string, and returns the number of tokens.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))
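
    # Editor's illustration (hypothetical values): token_length counts sub-word
    # tokens with the model's own tokenizer, e.g.
    #   self.token_length("passage: Livret A\nLe Livret A est un livret d'épargne.")
    # returns the number of tokens that embed_page later compares against
    # max_chunk_size (512 by default).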
    def passage_str(self, paragraphs, title):
        """
        Constructs a passage string from given paragraphs and a title.

        Args:
            paragraphs (list of str): A list of paragraphs.
            title (str): The title of the passage.

        Returns:
            str: A passage string that combines the title and paragraphs.

        This function takes a list of paragraphs and a title, and constructs a single string
        with the title followed by the paragraphs, formatted for embedding.
        """

        return f"passage: {title}\n" + '\n'.join(paragraphs)
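
    # Editor's illustration, with hypothetical inputs:
    #   self.passage_str(["Premier paragraphe.", "Second paragraphe."], "Livret A")
    # returns
    #   "passage: Livret A\nPremier paragraphe.\nSecond paragraphe."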
    def embed_page(self, filename, url, title, contents, tags, max_chunk_size=512):
        """
        Embeds the text of a webpage into a ChromaDB collection.

        Args:
            filename (str): The name of the file being processed.
            url (str): The URL of the webpage.
            title (str): The title of the webpage.
            contents (list of str): The contents of the webpage, split into paragraphs.
            tags (list of str): Tags derived from the URL.
            max_chunk_size (int): The maximum token length for a chunk. Defaults to 512.

        Returns:
            None

        This function divides the webpage content into chunks that fit within the max_chunk_size limit,
        embeds each chunk using the provided model, and stores the embeddings in the specified ChromaDB collection.
        """

        documents = []
        contents_to_embed = [contents]

        while contents_to_embed:
            last_item = contents_to_embed.pop()
            # (1) For the `multilingual-e5-large` embedding model,
            # the string of a document must be prepended with "passage:"
            # (2) Since the text of a webpage may have to be cut into many documents,
            # we always add the title of the webpage at the top of a document
            last_item_str = self.passage_str(last_item, title)
            last_item_token_length = self.token_length(last_item_str)

            if last_item_token_length > max_chunk_size:
                # If the text of the webpage, present in file `filename`,
                # contains more than `max_chunk_size` tokens, it must be divided
                # into multiple documents
                if len(last_item) > 1:
                    # If there are many paragraphs in `last_item`, i.e. the current
                    # part of the webpage for which an embedding will be made,
                    # the length of `last_item` can be reduced by dividing its set of
                    # paragraphs in half
                    h = len(last_item) // 2
                    last_item_h1 = last_item[:h]
                    last_item_h2 = last_item[h:]
                    contents_to_embed.append(last_item_h1)
                    contents_to_embed.append(last_item_h2)
                else:
                    # If `last_item` is made of only one long paragraph whose length is
                    # larger than `max_chunk_size`, this paragraph will be divided into two parts.
                    sentences = re.split(r'(?<=[.!?]) +', last_item[0])

                    if len(sentences) > 1:
                        # If there are multiple sentences, try to split into two parts
                        i = 1
                        while True:
                            part1 = ' '.join(sentences[:i])
                            part2 = ' '.join(sentences[i:])
                            token_length_part_1 = self.token_length(self.passage_str([part1], title))
                            token_length_part_2 = self.token_length(self.passage_str([part2], title))
                            if (token_length_part_1 <= max_chunk_size and
                                    token_length_part_2 <= max_chunk_size) or \
                                    token_length_part_1 > max_chunk_size:
                                break
                            i += 1
                    else:
                        # If there is only one long sentence or no suitable split was found, split by words
                        words = last_item[0].split()
                        h = len(words) // 2
                        part1 = ' '.join(words[:h])
                        part2 = ' '.join(words[h:])

                    contents_to_embed.append([part1])
                    contents_to_embed.append([part2])
            else:
                documents.append(last_item_str)

        # We want the documents into which a webpage has been divided
        # to be in the natural reading order
        documents.reverse()
        embeddings = self.model.encode(documents, normalize_embeddings=True)
        embeddings = embeddings.tolist()

        # We consider the subparts of a URL as tags describing the webpage.
        # For example,
        # "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
        # is associated with the tags:
        #     tags[0] == 'rhone-alpes'
        #     tags[1] == 'professionnels'
        #     tags[2] == 'financer-projets-optimiser-tresorerie'
        if len(tags) < 2:
            category = ''
        else:
            if tags[0] == 'rhone-alpes':
                category = tags[1]
            else:
                category = tags[0]
        metadata = {'category': category, 'url': url}
        # All the documents corresponding to the same webpage share the same metadata, i.e. URL and category
        metadatas = [copy.deepcopy(metadata) for _ in range(len(documents))]

        ids = [filename + '-' + str(i+1) for i in range(len(documents))]

        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
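
    # Editor's sketch (argument values are illustrative, mirroring the example
    # URL in the comment above):
    #   self.embed_page(
    #       filename='3c25273538.txt',
    #       url='https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/',
    #       title='Financer vos projets et optimiser votre trésorerie',
    #       contents=['Premier paragraphe.', 'Second paragraphe.'],
    #       tags=['rhone-alpes', 'professionnels', 'financer-projets-optimiser-tresorerie'])
    # would store one or more documents with ids '3c25273538.txt-1',
    # '3c25273538.txt-2', ... and, for each of them, the metadata
    #   {'category': 'professionnels', 'url': 'https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/'}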
    def embed_folder(self, folder_path):
        """
        Embeds all the .txt files within a specified folder into a ChromaDB collection using a specified embedding model.

        Args:
            folder_path (str): Path to the folder containing .txt files.

        Returns:
            None

        This function processes each .txt file in the given folder, extracts the content, and uses `embed_page`
        to embed the content into the specified ChromaDB collection.
        """

        for filename in os.listdir(folder_path):
            if filename.endswith('.txt'):
                file_path = os.path.join(folder_path, filename)
                with open(file_path, 'r') as file:
                    file_contents = file.read()
                contents_lst = [s.replace('\n', ' ').replace('\xa0', ' ') for s in file_contents.split('\n\n')]
                if len(contents_lst) < 3:  # contents_lst[0] is the URL, contents_lst[1] is the title, the rest is the content
                    continue
                url = contents_lst[0]
                if '?' in url:  # URLs with a '?' correspond to calls to services and have no useful content
                    continue
                title = contents_lst[1]
                if not title:  # when the title is absent (or empty), the page is of no interest
                    continue
                logging.info(f"{filename} : Start")
                prefix = 'https://www.caisse-epargne.fr/'
                suffix = url.replace(prefix, '')
                tags = suffix.split('/')
                tags = [tag for tag in tags if tag]  # remove empty parts
                self.embed_page(filename, url, title, contents_lst[2:], tags)
                logging.info(f"{filename} : Done")
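
    # Editor's note on the expected .txt layout (derived from the parsing above):
    # each file holds the URL, then the title, then one paragraph per block,
    # all separated by blank lines, e.g.
    #
    #   https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/
    #
    #   Financer vos projets et optimiser votre trésorerie
    #
    #   Premier paragraphe du contenu de la page...
    #
    #   Second paragraphe...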

rag_fr_embedding.ipynb (new file, 120 lines)
@@ -0,0 +1,120 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "98de82f6-2dc9-4d27-a5d8-d07ae04b496c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/peportier/miniforge3/envs/RAG_ENV/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n",
      "/Users/peportier/miniforge3/envs/RAG_ENV/lib/python3.9/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      " _torch_pytree._register_pytree_node(\n"
     ]
    }
   ],
   "source": [
    "from embedding import EmbeddingModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "37408a48-ce90-4176-bc9f-b71ebc22a178",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-03 11:13:53,279 - INFO - Load pretrained SentenceTransformer: intfloat/multilingual-e5-large\n",
      "/Users/peportier/miniforge3/envs/RAG_ENV/lib/python3.9/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      " _torch_pytree._register_pytree_node(\n",
      "2024-01-03 11:13:56,891 - INFO - Use pytorch device: cpu\n",
      "2024-01-03 11:13:56,894 - INFO - Anonymized telemetry enabled. See https://docs.trychroma.com/telemetry for more information.\n",
      "2024-01-03 11:13:56,990 - INFO - 4a06529f5f.txt : Start\n",
      "Batches: 0%| | 0/1 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
      "To disable this warning, you can either:\n",
      "\t- Avoid using `tokenizers` before the fork if possible\n",
      "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.64it/s]\n",
      "2024-01-03 11:13:57,660 - INFO - 4a06529f5f.txt : Done\n",
      "2024-01-03 11:13:57,660 - INFO - 4aac6081e0.txt : Start\n",
      "Token indices sequence length is longer than the specified maximum sequence length for this model (595 > 512). Running this sequence through the model will result in indexing errors\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.93it/s]\n",
      "2024-01-03 11:13:58,189 - INFO - 4aac6081e0.txt : Done\n",
      "2024-01-03 11:13:58,189 - INFO - 4a5736d002.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 5.89it/s]\n",
      "2024-01-03 11:13:58,365 - INFO - 4a5736d002.txt : Done\n",
      "2024-01-03 11:13:58,366 - INFO - 3d159cbe89.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.63it/s]\n",
      "2024-01-03 11:13:58,988 - INFO - 3d159cbe89.txt : Done\n",
      "2024-01-03 11:13:58,989 - INFO - 3f3e46760c.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 6.07it/s]\n",
      "2024-01-03 11:13:59,159 - INFO - 3f3e46760c.txt : Done\n",
      "2024-01-03 11:13:59,160 - INFO - 3ced86d1db.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 2.12it/s]\n",
      "2024-01-03 11:13:59,640 - INFO - 3ced86d1db.txt : Done\n",
      "2024-01-03 11:13:59,641 - INFO - 3bbe30b18a.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:01<00:00, 1.46s/it]\n",
      "2024-01-03 11:14:01,116 - INFO - 3bbe30b18a.txt : Done\n",
      "2024-01-03 11:14:01,116 - INFO - 3dbfdeb28e.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:01<00:00, 1.17s/it]\n",
      "2024-01-03 11:14:02,299 - INFO - 3dbfdeb28e.txt : Done\n",
      "2024-01-03 11:14:02,299 - INFO - 4adf02d48f.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 1.71it/s]\n",
      "2024-01-03 11:14:02,895 - INFO - 4adf02d48f.txt : Done\n",
      "2024-01-03 11:14:02,896 - INFO - 3c25273538.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:02<00:00, 2.02s/it]\n",
      "2024-01-03 11:14:04,940 - INFO - 3c25273538.txt : Done\n",
      "2024-01-03 11:14:04,940 - INFO - 4aeb967bdb.txt : Start\n",
      "Batches: 100%|████████████████████████████████████| 1/1 [00:00<00:00, 2.00it/s]\n",
      "2024-01-03 11:14:05,449 - INFO - 4aeb967bdb.txt : Done\n"
     ]
    }
   ],
   "source": [
    "model_name = 'intfloat/multilingual-e5-large'\n",
    "chromadb_path = './chromadbtest'\n",
    "folder_path = './docs/test'\n",
    "collection_name = 'cera'\n",
    "\n",
    "embedding_model = EmbeddingModel(model_name, chromadb_path, collection_name)\n",
    "embedding_model.embed_folder(folder_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2acd9c49-5676-4e72-9eff-f6fb8ffa94fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "RAG_ENV",
   "language": "python",
   "name": "rag_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
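
For reference, the collection populated by the notebook above can later be searched. A minimal sketch, assuming the usual multilingual-e5 convention of prefixing queries with "query:" (the counterpart of the "passage:" prefix used in embedding.py); the query text below is purely illustrative:

from sentence_transformers import SentenceTransformer
import chromadb

model = SentenceTransformer('intfloat/multilingual-e5-large')
client = chromadb.PersistentClient(path='./chromadbtest')
collection = client.get_or_create_collection(name='cera')

# e5 models expect search queries to be prefixed with "query:",
# just as the stored documents were prefixed with "passage:"
query_embeddings = model.encode(['query: Comment financer un projet professionnel ?'],
                                normalize_embeddings=True).tolist()
results = collection.query(query_embeddings=query_embeddings, n_results=3)
print(results['documents'][0])   # the three closest passages
print(results['metadatas'][0])   # their url and category metadata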