# rag/embedding2.py

import copy
import logging
import os
import re

import chromadb
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

logging.basicConfig(
    filename='embedding.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

class EmbeddingModel:
    """Embeds the content of HTML webpages into a ChromaDB collection."""

    def __init__(self, model_name, chromadb_path, collection_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def embed_folder(self, html_folder_path, txt_folder_path):
        """
        Embeds all the .html files within a specified folder into a ChromaDB collection using a specified embedding model.
        The txt folder is required to get the URL of the webpage. TODO: change this behavior in a future version.
        Args:
            html_folder_path (str): Path to the folder containing .html files.
            txt_folder_path (str): Path to the folder containing .txt files.
        Returns:
            None
        This function processes each .html file in the given folder, extracts the content, and uses `embed_page`
        to embed the content into the specified ChromaDB collection.
        """
        for html_filename in os.listdir(html_folder_path):
            html_file_path = os.path.join(html_folder_path, html_filename)
            txt_filename = re.sub(r'\.html$', '.txt', html_filename)
            txt_file_path = os.path.join(txt_folder_path, txt_filename)
            with open(txt_file_path, 'r') as file:
                txt_file_contents = file.read()
            url = txt_file_contents.split('\n')[0]
            if '?' in url:  # URLs with a '?' correspond to service calls and have no useful content
                continue
            if not url.startswith('https://www.caisse-epargne.fr/'):
                continue
            prefix = 'https://www.caisse-epargne.fr/'
            suffix = url.replace(prefix, '')
            tags = suffix.split('/')
            tags = [tag for tag in tags if tag]  # remove empty parts
            with open(html_file_path, 'r') as file:
                html_file_contents = file.read()
            soup = BeautifulSoup(html_file_contents, 'html.parser')
            first_section = soup.find('section')
            if not first_section:
                continue
            page_title_present = first_section.find('h1')
            if not page_title_present:
                continue
            page_title = page_title_present.get_text()
            sections = soup.find_all(lambda tag: tag.name == 'section' and 'key-informations' not in tag.get('class', []))
            struct_page = {'title': page_title}
            current_section = ''
            for section in sections:
                for wysiwyg_tag in section.find_all(class_='wysiwyg'):
                    # Check for a title within the wysiwyg container
                    internal_title = wysiwyg_tag.find(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']) or wysiwyg_tag.find('p', class_='title')
                    # If no internal title, find the nearest title from previous tags
                    if not internal_title:
                        nearest_title = None
                        for previous in wysiwyg_tag.find_all_previous():
                            if previous.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
                                nearest_title = previous.get_text().strip()
                                break
                            if previous.name == 'p' and 'title' in previous.get('class', []):
                                nearest_title = previous.get_text().strip()
                                break
                        if nearest_title:
                            nearest_title = re.sub(r'\(\d\)', '', nearest_title)
                            nearest_title = re.sub(r'^\d+\.\s*', '', nearest_title)
                            current_section = nearest_title
                            struct_page[current_section] = []
                        else:
                            continue
                    for child in wysiwyg_tag.find_all(['p', 'li', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
                        text = child.get_text().strip()
                        text = re.sub(r'\(\d\)', '', text)
                        if child.name.startswith('h') or (child.name == 'p' and 'title' in child.get('class', [])):
                            text = re.sub(r'^\d+\.\s*', '', text)
                            current_section = text
                            struct_page[current_section] = []
                        else:  # <p> not of class 'title', or <li>
                            if 'is-style-mentions' not in child.get('class', []):
                                if current_section in struct_page:
                                    struct_page[current_section].append(text)
            logging.info(f"{html_filename} : Start")
            self.embed_page(html_filename, url, struct_page, tags)
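
    # For illustration, embed_folder builds struct_page dictionaries of the form
    # below (hypothetical content, not taken from a real page):
    #     {
    #         'title': 'Livret A',
    #         'Les taux': ['Paragraph about interest rates...', 'Another paragraph...'],
    #         'Conditions': ['Paragraph about eligibility...'],
    #     }
    # i.e. the page title plus one list of paragraphs per detected section title,
    # which embed_page then turns into one or more embedded documents.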

    def token_length(self, text):
        """
        Calculates the token length of a given text.
        Args:
            text (str): The text to be tokenized.
        Returns:
            int: The number of tokens in the text.
        This function takes a string, tokenizes the string, and returns the number of tokens.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def passage_str(self, paragraphs, title, subtitle):
        """
        Constructs a passage string from given paragraphs, a title, and a subtitle.
        Args:
            paragraphs (list of str): A list of paragraphs.
            title (str): The title of the webpage.
            subtitle (str): The title of the passage.
        Returns:
            str: A passage string that combines the titles and paragraphs.
        This function takes a passage made of a list of paragraphs extracted
        from a webpage, the title of the webpage, and the subtitle corresponding to
        the passage, and constructs a single string with the titles followed by
        the paragraphs, formatted for embedding.
        """
        return f"passage: {title}\n\n{subtitle}\n\n" + '\n'.join(paragraphs)

    def embed_page(self, html_filename, url, struct_page, tags, max_chunk_size=512):
        """
        Embeds the sections of a single webpage into the ChromaDB collection.
        Args:
            html_filename (str): Name of the source .html file, used to build the document ids.
            url (str): URL of the webpage, stored in the metadata.
            struct_page (dict): The page title plus one list of paragraphs per section title.
            tags (list of str): Parts of the URL path, used to derive the category metadata.
            max_chunk_size (int): Maximum number of tokens per embedded document.
        Returns:
            None
        Each section becomes one document; sections longer than max_chunk_size tokens
        are split into several documents of consecutive paragraphs.
        """
        documents = []
        title = struct_page['title']
        for subtitle, paragraphs in struct_page.items():
            if subtitle != 'title':
                doc_str = self.passage_str(paragraphs, title, subtitle)
                doc_token_length = self.token_length(doc_str)
                if doc_token_length > max_chunk_size:
                    # Split the section into chunks of consecutive paragraphs that each
                    # stay below max_chunk_size tokens.
                    sub_paragraphs = []
                    sub_paragraphs_token_length = 0
                    paragraph_index = 0
                    while True:
                        # Greedily add paragraphs until the chunk reaches the limit or the
                        # section is exhausted.
                        while sub_paragraphs_token_length < max_chunk_size and paragraph_index < len(paragraphs):
                            sub_paragraphs.append(paragraphs[paragraph_index])
                            sub_paragraphs_str = self.passage_str(sub_paragraphs, title, subtitle)
                            sub_paragraphs_token_length = self.token_length(sub_paragraphs_str)
                            paragraph_index += 1
                        if paragraph_index == len(paragraphs):
                            if sub_paragraphs_token_length >= max_chunk_size and len(sub_paragraphs) > 1:
                                # The final chunk overflows: put its last paragraph in a chunk of its own.
                                sub_paragraphs_str_1 = self.passage_str(sub_paragraphs[:-1], title, subtitle)
                                sub_paragraphs_str_2 = self.passage_str([sub_paragraphs[-1]], title, subtitle)
                                documents.append(sub_paragraphs_str_1)
                                documents.append(sub_paragraphs_str_2)
                            else:
                                sub_paragraphs_str = self.passage_str(sub_paragraphs, title, subtitle)
                                documents.append(sub_paragraphs_str)
                            break
                        elif len(sub_paragraphs) == 1:
                            # A single paragraph already exceeds max_chunk_size: keep it as an
                            # oversized chunk instead of looping forever on the same paragraph.
                            documents.append(self.passage_str(sub_paragraphs, title, subtitle))
                            sub_paragraphs = []
                            sub_paragraphs_token_length = 0
                        else:
                            # The chunk overflows: emit it without its last paragraph, which
                            # becomes the first paragraph of the next chunk.
                            sub_paragraphs_str = self.passage_str(sub_paragraphs[:-1], title, subtitle)
                            documents.append(sub_paragraphs_str)
                            paragraph_index -= 1
                            sub_paragraphs = []
                            sub_paragraphs_token_length = 0
                else:
                    documents.append(doc_str)
        if not documents:
            return
        embeddings = self.model.encode(documents, normalize_embeddings=True)
        embeddings = embeddings.tolist()
        # We consider the subparts of a URL as tags describing the webpage.
        # For example,
        # "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
        # is associated with the tags:
        # tags[0] == 'rhone-alpes'
        # tags[1] == 'professionnels'
        # tags[2] == 'financer-projets-optimiser-tresorerie'
        if len(tags) < 2:
            category = ''
        elif tags[0] == 'rhone-alpes':
            category = tags[1]
        else:
            category = tags[0]
        metadata = {'category': category, 'url': url}
        # All the documents corresponding to the same webpage share the same metadata, i.e. URL and category.
        metadatas = [copy.deepcopy(metadata) for _ in range(len(documents))]
        ids = [f"{html_filename}-{i + 1}" for i in range(len(documents))]
        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)
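

if __name__ == '__main__':
    # Minimal usage sketch. The model name, paths, and collection name below are
    # illustrative assumptions, not values taken from this repository.
    embedder = EmbeddingModel(
        model_name='intfloat/multilingual-e5-base',  # assumed E5-style embedding model
        chromadb_path='./chroma_db',                 # assumed local persistence path
        collection_name='caisse-epargne',            # assumed collection name
    )
    embedder.embed_folder('data/html', 'data/txt')   # assumed folder layout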