2024-01-03 10:24:17 -05:00
|
|
|
from transformers import AutoTokenizer
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
import os
|
|
|
|
import re
|
|
|
|
import copy
|
|
|
|
import chromadb
|
|
|
|
import logging
|
|
|
|
|
|
|
|
# Configure the root logger once at import time: INFO level and above,
# each record formatted as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
class EmbeddingModel:
    """
    Chunks the text of scraped webpages, embeds the chunks and stores them in ChromaDB.

    The instance holds a tokenizer (used only to measure chunk sizes), a
    SentenceTransformer model (used to compute the embeddings) and a persistent
    ChromaDB collection (where documents, embeddings and metadata are stored).
    """

    def __init__(self, model_name, chromadb_path, collection_name):
        """
        Loads the tokenizer and the embedding model and opens the ChromaDB collection.

        Args:
            model_name (str): Name passed to both `AutoTokenizer.from_pretrained`
                and `SentenceTransformer`.
            chromadb_path (str): Path given to `chromadb.PersistentClient`.
            collection_name (str): Name of the collection to get or create.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def token_length(self, text):
        """
        Calculates the token length of a given text.

        Args:
            text (str): The text to be tokenized.

        Returns:
            int: The number of tokens in the text, special tokens excluded.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def passage_str(self, paragraphs, title):
        """
        Constructs a passage string from given paragraphs and a title.

        Args:
            paragraphs (list of str): A list of paragraphs.
            title (str): The title of the passage.

        Returns:
            str: "passage: <title>" on the first line, followed by the
                paragraphs, one per line.
        """
        return f"passage: {title}\n" + '\n'.join(paragraphs)

    def _split_paragraph(self, paragraph, title, max_chunk_size):
        """
        Splits one over-long paragraph into two non-empty parts.

        Args:
            paragraph (str): The paragraph to split.
            title (str): The page title, needed to measure each candidate part
                the same way it will be measured once it becomes a document.
            max_chunk_size (int): The maximum token length for a chunk.

        Returns:
            tuple of (str, str) or None: The two parts, or None when the
                paragraph holds fewer than two words and cannot be divided.

        The split is made on a sentence boundary when the paragraph has several
        sentences, otherwise in the middle of its words. Both parts are always
        non-empty, so each call strictly reduces the amount of text per part.
        """
        # Empty fragments (e.g. produced by trailing spaces) are dropped so
        # that no part can ever be an empty string.
        sentences = [s for s in re.split(r'(?<=[.!?]) +', paragraph) if s]

        if len(sentences) > 1:
            # Move the boundary forward until both parts fit, or until the
            # first part itself gets too long. The boundary never goes past
            # the last sentence: the second part therefore always keeps at
            # least one sentence, which guarantees progress.
            for i in range(1, len(sentences)):
                part1 = ' '.join(sentences[:i])
                part2 = ' '.join(sentences[i:])
                if self.token_length(self.passage_str([part1], title)) > max_chunk_size:
                    break
                if self.token_length(self.passage_str([part2], title)) <= max_chunk_size:
                    break
            return part1, part2

        # Only one sentence: split in the middle of its words.
        words = paragraph.split()
        if len(words) < 2:
            return None
        h = len(words) // 2
        return ' '.join(words[:h]), ' '.join(words[h:])

    def embed_page(self, filename, url, title, contents, tags, max_chunk_size=512):
        """
        Embeds the text of a webpage into a ChromaDB collection.

        Args:
            filename (str): The name of the file being processed; used as prefix of the document ids.
            url (str): The URL of the webpage.
            title (str): The title of the webpage.
            contents (list of str): The contents of the webpage, split into paragraphs.
            tags (list of str): Tags derived from the URL.
            max_chunk_size (int): The maximum token length for a chunk. Defaults to 512.

        Returns:
            None

        This function divides the webpage content into chunks that fit within the max_chunk_size limit,
        embeds each chunk and stores the embeddings in the collection. A chunk that cannot be divided
        any further (a single word, or a title that is itself longer than the limit) is stored as-is
        and a warning is logged.
        """
        documents = []
        # Stack of paragraph lists still to be turned into documents.
        pending = [contents]

        while pending:
            chunk = pending.pop()
            # (1) For the `multilingual-e5-large` embedding model,
            #     the string of a document must be prepended with "passage:"
            # (2) Since the text of a webpage may have to be cut into many documents,
            #     we always add the title of the webpage at the top of a document
            chunk_str = self.passage_str(chunk, title)

            if self.token_length(chunk_str) <= max_chunk_size:
                documents.append(chunk_str)
                continue

            if len(chunk) > 1:
                # Several paragraphs: halve the list of paragraphs.
                h = len(chunk) // 2
                pending.append(chunk[:h])
                pending.append(chunk[h:])
                continue

            # At most one paragraph left: split inside the paragraph.
            parts = self._split_paragraph(chunk[0], title, max_chunk_size) if chunk else None
            if parts is None:
                # Nothing left to split; keep the chunk rather than loop forever.
                logging.warning(
                    "%s : chunk exceeds %d tokens and cannot be split further",
                    filename, max_chunk_size)
                documents.append(chunk_str)
                continue
            pending.append([parts[0]])
            pending.append([parts[1]])

        # The stack yields the last part first; we want the documents into
        # which a webpage has been divided to be in the natural reading order.
        documents.reverse()
        embeddings = self.model.encode(documents, normalize_embeddings=True).tolist()

        # We consider the subparts of a URL as tags describing the webpage.
        # For example,
        # "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
        # is associated to the tags:
        # tags[0] == 'rhone-alpes'
        # tags[1] == 'professionnels'
        # tags[2] == 'financer-projets-optimiser-tresorerie'
        if len(tags) < 2:
            category = ''
        elif tags[0] == 'rhone-alpes':
            category = tags[1]
        else:
            category = tags[0]

        # All the documents of a same webpage share the same metadata (URL and
        # category); each one gets its own dict.
        metadatas = [{'category': category, 'url': url} for _ in documents]
        ids = [f"{filename}-{i}" for i in range(1, len(documents) + 1)]

        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)

    def remove_duplicate(self, lst):
        """
        Removes consecutive duplicates from a list of lines, in place.

        Args:
            lst (list of str): The lines of a file.

        Returns:
            list of str: The same list object, after removal.

        The file contents can contain duplicate lines because we keep the textual
        content of multiple html tags that can be embedded one in another. Two
        patterns are collapsed: `X, '', X, ''` becomes `X, ''`, and `X, X` becomes `X`.
        """
        i = 0
        while i < len(lst) - 1:
            if i < len(lst) - 3 and lst[i] == lst[i + 2] and lst[i + 1] == lst[i + 3] == '':
                # Remove lst[i+1] and lst[i+2]: what remains is lst[i] followed
                # by the former lst[i+3], i.e. a single `X, ''` pair. `i` is not
                # advanced so that longer repetitions are collapsed too.
                del lst[i + 1:i + 3]
            elif lst[i] == lst[i + 1]:
                # Remove lst[i+1]
                del lst[i + 1]
            else:
                i += 1
        return lst

    def remove_footer(self, lst):
        """
        Truncates a list of paragraphs at the first occurrence of the page footer, in place.

        Args:
            lst (list of str): The paragraphs of a page.

        Returns:
            list of str: The same list object, cut just before the three
                consecutive footer paragraphs when they are present.
        """
        sequence = ["Caisse d'Epargne", "Rhône Alpes", "Formuler une demande en ligne"]
        for i in range(len(lst) - 2):
            if lst[i:i + 3] == sequence:
                del lst[i:]
                break
        return lst

    def embed_folder(self, folder_path):
        """
        Embeds all the .txt files within a specified folder into the ChromaDB collection.

        Args:
            folder_path (str): Path to the folder containing .txt files.

        Returns:
            None

        Each file is expected to hold blank-line-separated paragraphs: the URL
        first, the title second, then the content. Files with fewer than three
        paragraphs, with a '?' in the URL, with a URL outside the
        'rhone-alpes' section, or with an empty title are skipped. The other
        files are passed to `embed_page`.
        """
        prefix = 'https://www.caisse-epargne.fr/'

        for filename in os.listdir(folder_path):
            if not filename.endswith('.txt'):
                continue

            file_path = os.path.join(folder_path, filename)
            # NOTE(review): no explicit encoding, so the platform default is
            # used; confirm how these files were written before pinning one.
            with open(file_path, 'r') as file:
                file_contents = file.read()

            lines = self.remove_duplicate(file_contents.split('\n'))
            file_contents = '\n'.join(lines)
            contents_lst = [paragraph.replace('\n', ' ').replace('\xa0', ' ')
                            for paragraph in file_contents.split('\n\n')]
            contents_lst = self.remove_footer(contents_lst)

            # contents_lst[0] is the URL, contents_lst[1] is the title, the rest is the content
            if len(contents_lst) < 3:
                continue

            url = contents_lst[0]
            if '?' in url:  # URLs with a '?' correspond to calls to services and have no useful content
                continue
            if not url.startswith('https://www.caisse-epargne.fr/rhone-alpes/'):
                continue

            title = contents_lst[1]
            if not title:  # when the title is absent (or empty), the page has no interest
                continue

            logging.info("%s : Start", filename)
            # The URL is known to start with `prefix` here, so only the
            # leading occurrence is stripped.
            suffix = url[len(prefix):]
            tags = [tag for tag in suffix.split('/') if tag]  # remove empty parts
            self.embed_page(filename, url, title, contents_lst[2:], tags)
            logging.info("%s : Done", filename)
|