# Listing header carried over from the file viewer: 213 lines, 9.5 KiB.

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import os
import re
import copy
import chromadb
import logging
# Configure the root logger when the module is loaded: INFO level and above,
# each record formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class EmbeddingModel:
    """Chunk scraped web pages, embed the chunks and store them in ChromaDB.

    A page is a title plus a list of paragraphs. Pages whose text exceeds the
    token budget are recursively split (by paragraphs, then sentences, then
    words) into several documents; every document is prefixed with
    ``"passage: <title>"`` before being embedded.
    """

    # A sentence boundary is a run of spaces that follows '.', '!' or '?'.
    _SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?]) +')
    _SITE_PREFIX = 'https://www.caisse-epargne.fr/'
    _REGION_PREFIX = 'https://www.caisse-epargne.fr/rhone-alpes/'
    _REGION_TAG = 'rhone-alpes'

    def __init__(self, model_name, chromadb_path, collection_name):
        """Load the tokenizer and embedding model and open the collection.

        Args:
            model_name (str): Name passed to both ``AutoTokenizer`` and
                ``SentenceTransformer``.
            chromadb_path (str): Directory of the persistent ChromaDB store.
            collection_name (str): Collection to fetch, or create if absent.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def token_length(self, text):
        """Return the number of tokens in ``text``.

        Args:
            text (str): The text to be tokenized.

        Returns:
            int: Token count, special tokens excluded.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def passage_str(self, paragraphs, title):
        """Build the string that is embedded for one document.

        Args:
            paragraphs (list of str): Paragraphs of the document.
            title (str): Title of the web page.

        Returns:
            str: ``"passage: <title>"`` on the first line, followed by the
            paragraphs, one per line.
        """
        return f"passage: {title}\n" + '\n'.join(paragraphs)

    def _split_in_two(self, paragraphs, title, max_chunk_size):
        """Split an oversized chunk into two strictly smaller chunks.

        Args:
            paragraphs (list of str): Paragraphs of the oversized chunk.
            title (str): Page title (counts towards the token budget).
            max_chunk_size (int): Token budget of a document.

        Returns:
            tuple or None: ``(first, second)`` paragraph lists in reading
            order, or ``None`` when the chunk cannot be reduced any further
            (no text, or a single word).
        """
        if len(paragraphs) > 1:
            # Several paragraphs: cut the list of paragraphs in half.
            middle = len(paragraphs) // 2
            return paragraphs[:middle], paragraphs[middle:]
        if not paragraphs:
            return None

        text = paragraphs[0]
        sentences = self._SENTENCE_BOUNDARY.split(text)
        if len(sentences) > 1:
            # Take the first cut where both parts fit, or where the head
            # starts to overflow. The cut never reaches len(sentences), so
            # the tail is never "everything" and the head never the whole
            # paragraph again; if no cut qualifies, detach the last sentence.
            cut = len(sentences) - 1
            for i in range(1, len(sentences)):
                head = ' '.join(sentences[:i])
                tail = ' '.join(sentences[i:])
                head_len = self.token_length(self.passage_str([head], title))
                tail_len = self.token_length(self.passage_str([tail], title))
                if head_len > max_chunk_size or tail_len <= max_chunk_size:
                    cut = i
                    break
            return [' '.join(sentences[:cut])], [' '.join(sentences[cut:])]

        # One long sentence: cut it in half by words.
        words = text.split()
        if len(words) < 2:
            return None
        middle = len(words) // 2
        return [' '.join(words[:middle])], [' '.join(words[middle:])]

    @classmethod
    def _category_from_tags(cls, tags):
        """Derive the page category from the path segments of its URL.

        For ``.../rhone-alpes/professionnels/financer-projets/`` the tags are
        ``['rhone-alpes', 'professionnels', 'financer-projets']`` and the
        category is ``'professionnels'``. Fewer than two tags give ``''``.
        """
        if len(tags) < 2:
            return ''
        if tags[0] == cls._REGION_TAG:
            return tags[1]
        return tags[0]

    def embed_page(self, filename, url, title, contents, tags, max_chunk_size=512):
        """Embed the text of a web page into the ChromaDB collection.

        The page is divided into documents of at most ``max_chunk_size``
        tokens (title included), kept in natural reading order. All documents
        of a page share the same metadata (``category`` and ``url``) and get
        the ids ``<filename>-1``, ``<filename>-2``, ...

        A chunk that is still too long but cannot be split any further (for
        example when the title alone exceeds the budget) is stored as is and
        a warning is logged, instead of looping forever.

        Args:
            filename (str): The name of the file being processed.
            url (str): The URL of the web page.
            title (str): The title of the web page.
            contents (list of str): The page text, split into paragraphs.
            tags (list of str): Tags derived from the URL.
            max_chunk_size (int): Maximum token length of a document.
        """
        documents = []
        # Stack of chunks still to process. The part that comes first in
        # reading order is always pushed last, so it is popped first.
        pending = [contents]
        while pending:
            paragraphs = pending.pop()
            # (1) For the `multilingual-e5-large` embedding model, the string
            #     of a document must be prepended with "passage:".
            # (2) A page may be cut into many documents, so the page title is
            #     repeated at the top of each of them.
            passage = self.passage_str(paragraphs, title)
            if self.token_length(passage) <= max_chunk_size:
                documents.append(passage)
                continue

            halves = self._split_in_two(paragraphs, title, max_chunk_size)
            if halves is None:
                logging.warning(
                    "%s : chunk exceeds %d tokens and cannot be split further",
                    filename, max_chunk_size)
                documents.append(passage)
                continue
            first, second = halves
            pending.append(second)
            pending.append(first)

        embeddings = self.model.encode(documents, normalize_embeddings=True).tolist()

        metadata = {'category': self._category_from_tags(tags), 'url': url}
        # One independent metadata dict per document.
        metadatas = [dict(metadata) for _ in documents]
        ids = [f"{filename}-{n}" for n in range(1, len(documents) + 1)]
        self.collection.add(embeddings=embeddings, documents=documents,
                            metadatas=metadatas, ids=ids)

    def remove_duplicate(self, lst):
        """Remove repeated lines from ``lst`` in place and return it.

        The scraped text can contain duplicate lines because the textual
        content of nested HTML tags is kept for each tag. Two patterns are
        collapsed: ``X, '', X, ''`` becomes ``X, ''`` and ``X, X`` becomes
        ``X``.

        Args:
            lst (list of str): Lines of a page file.

        Returns:
            list of str: The same list object, deduplicated.
        """
        i = 0
        while i < len(lst) - 1:
            if i < len(lst) - 3 and lst[i] == lst[i + 2] and lst[i + 1] == lst[i + 3] == '':
                # Drop lst[i+1] and lst[i+2]; the line and one blank remain.
                del lst[i + 1:i + 3]
            elif lst[i] == lst[i + 1]:
                del lst[i + 1]
            else:
                # Advance only when nothing was removed, so that longer runs
                # of duplicates are fully collapsed.
                i += 1
        return lst

    def remove_footer(self, lst):
        """Truncate ``lst`` in place at the start of the page footer.

        The footer starts at the first occurrence of three consecutive
        paragraphs: "Caisse d'Epargne", "Rhône Alpes",
        "Formuler une demande en ligne".

        Args:
            lst (list of str): Paragraphs of a page.

        Returns:
            list of str: The same list object, without the footer.
        """
        sequence = ["Caisse d'Epargne", "Rhône Alpes", "Formuler une demande en ligne"]
        for i in range(len(lst) - 2):
            if lst[i:i + 3] == sequence:
                del lst[i:]
                break
        return lst

    def embed_folder(self, folder_path):
        """Embed every ``.txt`` file of a folder into the ChromaDB collection.

        Each file is expected to contain paragraphs separated by blank lines:
        the URL first, then the title, then the page content. A file is
        skipped when it has fewer than three paragraphs, when its URL
        contains ``'?'`` (calls to services, no useful content), when the URL
        is outside ``https://www.caisse-epargne.fr/rhone-alpes/`` or when the
        title is empty.

        Args:
            folder_path (str): Path to the folder containing the .txt files.
        """
        for filename in os.listdir(folder_path):
            if not filename.endswith('.txt'):
                continue
            file_path = os.path.join(folder_path, filename)
            # The files contain accented French text: do not depend on the
            # platform's default encoding.
            with open(file_path, 'r', encoding='utf-8') as file:
                file_contents = file.read()

            lines = self.remove_duplicate(file_contents.split('\n'))
            paragraphs = '\n'.join(lines).split('\n\n')
            contents_lst = [
                paragraph.replace('\n', ' ').replace('\xa0', ' ')
                for paragraph in paragraphs
            ]
            contents_lst = self.remove_footer(contents_lst)

            # contents_lst[0] is the URL, contents_lst[1] is the title,
            # the rest is the content.
            if len(contents_lst) < 3:
                continue
            url = contents_lst[0]
            if '?' in url:
                continue
            if not url.startswith(self._REGION_PREFIX):
                continue
            title = contents_lst[1]
            if not title:
                continue

            logging.info("%s : Start", filename)
            suffix = url.replace(self._SITE_PREFIX, '')
            tags = [tag for tag in suffix.split('/') if tag]  # drop empty parts
            self.embed_page(filename, url, title, contents_lst[2:], tags)
            logging.info("%s : Done", filename)