import copy
import logging
import os
import re

import chromadb
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

logging.basicConfig(filename='embedding.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class EmbeddingModel:
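    """
    Embeds the content of HTML webpages into a ChromaDB collection.

    Constructor args:
        model_name (str): Hugging Face model name, used for both the tokenizer and the SentenceTransformer.
        chromadb_path (str): Path of the persistent ChromaDB database.
        collection_name (str): Name of the ChromaDB collection documents are added to.
        multilingual_e5 (bool): If True, passages are prefixed with "passage: ", as expected by multilingual-e5 models.
    """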

    def __init__(self, model_name, chromadb_path, collection_name, multilingual_e5=True):
        self.multilingual_e5 = multilingual_e5
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def token_length(self, text):
        """
        Calculates the token length of a given text.

        Args:
            text (str): The text to be tokenized.

        Returns:
            int: The number of tokens in the text.

        This function takes a string, tokenizes the string, and returns the number of tokens.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def embed_folder(self, html_folder_path, txt_folder_path):
        """
        Embeds all the .html files within a specified folder into a ChromaDB collection using the embedding model.
        The txt folder is required to get the URL of the webpage. TODO: change this behavior in a future version.

        Args:
            html_folder_path (str): Path to the folder containing .html files.
            txt_folder_path (str): Path to the folder containing .txt files.

        Returns:
            None

        This function processes each .html file in the given folder, extracts the content, and uses `embed_page`
        to embed the content into the specified ChromaDB collection.
        """
        for html_filename in os.listdir(html_folder_path):
            html_file_path = os.path.join(html_folder_path, html_filename)

            txt_filename = re.sub(r'\.html', '.txt', html_filename)
            txt_file_path = os.path.join(txt_folder_path, txt_filename)
            with open(txt_file_path, 'r') as file:
                txt_file_contents = file.read()

            url = txt_file_contents.split('\n')[0]
            if '?' in url:  # URLs containing a '?' correspond to service calls and have no useful content
                continue
            if not url.startswith('https://www.caisse-epargne.fr/rhone-alpes/'):
                continue

            prefix = 'https://www.caisse-epargne.fr/'
            suffix = url.replace(prefix, '')
            tags = suffix.split('/')
            tags = [tag for tag in tags if tag]  # remove empty parts

            with open(html_file_path, 'r') as file:
                html_file_contents = file.read()

            soup = BeautifulSoup(html_file_contents, 'html.parser')

            first_section = soup.find('section')
            if not first_section:
                continue
            page_title_present = first_section.find('h1')
            if not page_title_present:
                continue
            page_title = page_title_present.get_text()

            sections = soup.find_all('section')

            struct_page = {'title': page_title}
            current_section = ''
            titles = [page_title]
            for section in sections:
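                # "key-informations" sections hold the page's key points: the title and
                # description of each block are concatenated and stored under 'Les points clés'.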
                if 'key-informations' in section.get('class', []):
                    key_items = []
                    for key_item in section.find_all('div', class_='container-block'):
                        key_item_text = ''
                        for key_item_title in key_item.find_all('div', class_='button'):
                            key_item_text += key_item_title.get_text().strip()
                        for key_item_desc in key_item.find_all('div', class_="tab-panel"):
                            key_item_text += ' ' + key_item_desc.get_text().strip()
                        if len(key_item_text) > 0:
                            key_items.append(key_item_text)
                    if len(key_items) > 0:
                        struct_page['Les points clés'] = key_items
                    continue
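                # "wysiwyg" containers hold the main text content; their paragraphs are
                # grouped in struct_page under the nearest title found in or before them.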
for wysiwyg_tag in section.find_all(class_="wysiwyg"):
|
|
# Check for a title within the wysiwyg container
|
|
internal_title = wysiwyg_tag.find(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']) or wysiwyg_tag.find('p', class_='title')
|
|
|
|
if internal_title:
|
|
title_tag = internal_title
|
|
title = internal_title.get_text().strip()
|
|
title = re.sub(r'\(\d\)', '', title)
|
|
title = re.sub(r'^\d+\.\s*', '', title)
|
|
titles.append(title)
|
|
current_section = title
|
|
else: # If no internal title, find the nearest title from previous tags
|
|
title_tag = None
|
|
current_section = titles[-1]
|
|
|
|
if current_section not in struct_page:
|
|
struct_page[current_section] = []
|
|
|
|
for child in wysiwyg_tag.find_all(['p', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
|
|
if child == title_tag:
|
|
continue
|
|
if 'is-style-mentions' in child.get('class', []):
|
|
continue
|
|
text = child.get_text().strip()
|
|
text = re.sub(r'\(\d\)', '', text)
|
|
struct_page[current_section].append(text)
|
|
|
|
if len(struct_page[current_section]) == 0:
|
|
del struct_page[current_section]
|
|
|
|
logging.info(f"{html_filename} : Start")
|
|
self.embed_page(html_filename, url, struct_page, tags)
|
|
|
|

    def passage_str(self, paragraphs, title, subtitle):
        """
        Constructs a passage string from given paragraphs, a title, and a subtitle.

        Args:
            paragraphs (list of str): A list of paragraphs.
            title (str): The title of the webpage.
            subtitle (str): The title of the passage.

        Returns:
            str: A passage string that combines the titles and paragraphs.

        This function takes a passage made of a list of paragraphs extracted
        from a webpage, the title of the webpage, and the subtitle corresponding to
        the passage, and constructs a single string with the titles followed by
        the paragraphs, formatted for embedding.
        """
        if self.multilingual_e5:
            prefix = "passage: "
        else:
            prefix = ""
        return f"{prefix}{title}\n\n{subtitle}\n\n" + '\n'.join(paragraphs)

    def embed_page(self, html_filename, url, struct_page, tags, max_chunk_size=500):
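        """
        Embeds the structured content of a single webpage into the ChromaDB collection.

        Args:
            html_filename (str): Name of the .html file, used to build the document ids.
            url (str): URL of the webpage, stored in the metadata.
            struct_page (dict): Mapping of section subtitles to lists of paragraphs,
                plus a 'title' entry holding the page title.
            tags (list of str): Parts of the URL path, used to derive the 'category' metadata.
            max_chunk_size (int): Maximum number of tokens allowed per embedded document.

        Returns:
            None

        Sections longer than `max_chunk_size` tokens are split into smaller chunks
        before being embedded and added to the collection.
        """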
        documents = []
        title = struct_page['title']

        for subtitle, paragraphs in struct_page.items():
            if subtitle != 'title':
                doc_str = self.passage_str(paragraphs, title, subtitle)
                doc_token_length = self.token_length(doc_str)

                if doc_token_length > max_chunk_size:

                    long_passages = []
                    sub_paragraphs = []
                    sub_paragraphs_token_length = 0
                    paragraph_index = 0
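                    # Greedily accumulate paragraphs until the passage exceeds max_chunk_size
                    # tokens, flush the accumulated paragraphs as one document, then restart
                    # from the paragraph that caused the overflow. Paragraphs that are too long
                    # on their own are collected in long_passages and split by words afterwards.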
                    while True:
                        while sub_paragraphs_token_length < max_chunk_size and paragraph_index < len(paragraphs):
                            sub_paragraphs.append(paragraphs[paragraph_index])
                            sub_paragraphs_str = self.passage_str(sub_paragraphs, title, subtitle)
                            sub_paragraphs_token_length = self.token_length(sub_paragraphs_str)
                            paragraph_index += 1
                        if paragraph_index >= len(paragraphs):
                            if sub_paragraphs_token_length >= max_chunk_size:
                                sub_paragraphs_str_1 = self.passage_str(sub_paragraphs[:-1], title, subtitle)
                                sub_paragraphs_str_2 = self.passage_str([sub_paragraphs[-1]], title, subtitle)
                                documents.append(sub_paragraphs_str_1)
                                if self.token_length(sub_paragraphs_str_2) < max_chunk_size:
                                    documents.append(sub_paragraphs_str_2)
                                else:
                                    # the last paragraph is itself longer than max_chunk_size
                                    long_passages.append(sub_paragraphs[-1])
                            else:
                                documents.append(sub_paragraphs_str)
                            break
                        else:  # sub_paragraphs_token_length >= max_chunk_size and paragraph_index < len(paragraphs)
                            if len(sub_paragraphs) > 1:
                                sub_paragraphs_str = self.passage_str(sub_paragraphs[:-1], title, subtitle)
                                documents.append(sub_paragraphs_str)
                                paragraph_index -= 1
                            else:
                                long_passages.append(sub_paragraphs[0])
                            sub_paragraphs = []
                            sub_paragraphs_token_length = 0
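                    # Paragraphs that exceed max_chunk_size on their own are split at word
                    # boundaries: words are accumulated until the chunk would overflow, the
                    # chunk is flushed, and the remaining word starts a new chunk.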
                    for long_passage in long_passages:
                        passage = []
                        for word in long_passage.split():
                            passage.append(word)
                            passage_str = self.passage_str([' '.join(passage)], title, subtitle)
                            if self.token_length(passage_str) > max_chunk_size:
                                passage_str = self.passage_str([' '.join(passage[:-1])], title, subtitle)
                                documents.append(passage_str)
                                passage = [passage[-1]]
                        passage_str = self.passage_str([' '.join(passage)], title, subtitle)
                        documents.append(passage_str)

                else:
                    documents.append(doc_str)

        if len(documents) == 0:
            return

        embeddings = self.model.encode(documents, normalize_embeddings=True)
        embeddings = embeddings.tolist()

        # We consider the subparts of a URL as tags describing the webpage.
        # For example,
        # "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
        # is associated with the tags:
        #     tags[0] == 'rhone-alpes'
        #     tags[1] == 'professionnels'
        #     tags[2] == 'financer-projets-optimiser-tresorerie'
        if len(tags) < 2:
            category = ''
        else:
            if tags[0] == 'rhone-alpes':
                category = tags[1]
            else:
                category = tags[0]
        metadata = {'category': category, 'url': url}
        # All the documents corresponding to the same webpage share the same metadata, i.e. URL and category
        metadatas = [copy.deepcopy(metadata) for _ in range(len(documents))]

        ids = [html_filename + '-' + str(i + 1) for i in range(len(documents))]

        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
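

if __name__ == "__main__":
    # Minimal usage sketch: the model name, paths, and collection name below are
    # placeholders (not part of the original pipeline) and should be adapted.
    embedder = EmbeddingModel(
        model_name="intfloat/multilingual-e5-base",
        chromadb_path="./chroma_db",
        collection_name="webpages",
    )
    embedder.embed_folder("./data/html", "./data/txt")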