# rag/embedding2.py
# (Repository web-view header: 274 lines, 12 KiB, Python; view links "Raw", "Permalink", "Normal View", "History".)
# Blame timestamp: 2024-01-10 14:06:42 +00:00
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import os
import re
import copy
import chromadb
import logging
from bs4 import BeautifulSoup
# Module-level logging setup, run at import time: records at INFO level and above are
# appended to 'embedding.log' (a relative path, so it is resolved against the current
# working directory) with a timestamp and level prefix.
logging.basicConfig(filename='embedding.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class EmbeddingModel:
    """Embed scraped web pages into a persistent ChromaDB collection.

    Each page is split into passages made of the page title, a section title
    and that section's paragraphs. Passages are embedded with a
    SentenceTransformer model and stored together with the page URL and a
    category derived from the URL path.
    """

    # Root of the crawled site; the URL path segments after it become the page tags.
    _SITE_PREFIX = 'https://www.caisse-epargne.fr/'
    # Only pages under this prefix are embedded by `embed_folder`.
    _REGION_PREFIX = _SITE_PREFIX + 'rhone-alpes/'
    _HEADING_TAGS = ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']

    def __init__(self, model_name, chromadb_path, collection_name, mulitlingual_e5=True):
        """Load the tokenizer and model, and open (or create) the ChromaDB collection.

        Args:
            model_name (str): Model name used for both the tokenizer (to measure
                passage lengths) and the SentenceTransformer (to embed passages).
            chromadb_path (str): Directory of the persistent ChromaDB database.
            collection_name (str): Collection that passages are added to; it is
                created if it does not exist yet.
            mulitlingual_e5 (bool): If True, every passage is prefixed with
                "passage: " (see `passage_str`). The misspelled name is kept for
                backward compatibility.
        """
        self.mulitlingual_e5 = mulitlingual_e5
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = SentenceTransformer(model_name)
        self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
        self.collection = self.chroma_client.get_or_create_collection(name=collection_name)

    def token_length(self, text):
        """Return the number of tokens in ``text``, excluding special tokens.

        Args:
            text (str): The text to tokenize.

        Returns:
            int: Number of tokens produced by ``self.tokenizer``.
        """
        return len(self.tokenizer.encode(text, add_special_tokens=False))

    def passage_str(self, paragraphs, title, subtitle):
        """Build the text of one passage, formatted for embedding.

        Args:
            paragraphs (list of str): Paragraphs of the passage, joined with newlines.
            title (str): Title of the webpage.
            subtitle (str): Title of the section the passage belongs to.

        Returns:
            str: ``"<prefix><title>\\n\\n<subtitle>\\n\\n<paragraphs>"``, where
            prefix is "passage: " when ``self.mulitlingual_e5`` is set and
            empty otherwise.
        """
        prefix = "passage: " if self.mulitlingual_e5 else ""
        return f"{prefix}{title}\n\n{subtitle}\n\n" + '\n'.join(paragraphs)

    def embed_folder(self, html_folder_path, txt_folder_path):
        """Embed every .html page of a folder into the ChromaDB collection.

        For each ``<name>.html`` in ``html_folder_path``, the page URL is read
        from the first line of ``<name>.txt`` in ``txt_folder_path``.
        TODO: get the URL another way so the txt folder is no longer required.

        A page is skipped when the matching .txt file is missing, when its URL
        contains a '?' (service calls with no useful content), when its URL is
        not under the Rhône-Alpes section of the site, or when it has no
        <section> containing an <h1>.

        Args:
            html_folder_path (str): Folder containing the .html files.
            txt_folder_path (str): Folder containing the matching .txt files.

        Returns:
            None
        """
        for html_filename in sorted(os.listdir(html_folder_path)):
            stem, extension = os.path.splitext(html_filename)
            if extension.lower() != '.html':
                continue  # not a page: the docstring promises .html files only
            html_file_path = os.path.join(html_folder_path, html_filename)
            txt_file_path = os.path.join(txt_folder_path, stem + '.txt')

            url = self._read_page_url(txt_file_path)
            if url is None:
                logging.warning("%s : no matching .txt file, skipped", html_filename)
                continue
            if '?' in url:  # URLs with a '?' correspond to calls to services and have no useful content
                continue
            if not url.startswith(self._REGION_PREFIX):
                continue
            # ".../rhone-alpes/professionnels/x/" -> ['rhone-alpes', 'professionnels', 'x']
            tags = [tag for tag in url[len(self._SITE_PREFIX):].split('/') if tag]

            with open(html_file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file.read(), 'html.parser')
            struct_page = self._parse_page(soup)
            if struct_page is None:
                continue

            logging.info("%s : Start", html_filename)
            self.embed_page(html_filename, url, struct_page, tags)

    @staticmethod
    def _read_page_url(txt_file_path):
        """Return the stripped first line of ``txt_file_path``, or None if the file does not exist."""
        try:
            with open(txt_file_path, 'r', encoding='utf-8') as file:
                return file.readline().strip()
        except FileNotFoundError:
            return None

    def _parse_page(self, soup):
        """Extract the page structure from a parsed HTML page.

        Returns:
            dict or None: ``{'title': page_title, section_title: [paragraph, ...], ...}``,
            where 'Les points clés' holds the key-information items. Returns
            None when the first <section> is missing or contains no <h1>.
        """
        first_section = soup.find('section')
        if first_section is None:
            return None
        page_title_tag = first_section.find('h1')
        if page_title_tag is None:
            return None
        page_title = page_title_tag.get_text().strip()

        struct_page = {'title': page_title}
        titles = [page_title]  # every section title seen so far, in document order
        for section in soup.find_all('section'):
            if 'key-informations' in section.get('class', []):
                key_items = self._key_informations(section)
                if key_items:
                    struct_page['Les points clés'] = key_items
                continue

            for wysiwyg_tag in section.find_all(class_="wysiwyg"):
                title_tag = wysiwyg_tag.find(self._HEADING_TAGS) or wysiwyg_tag.find('p', class_='title')
                if title_tag is not None:
                    section_title = title_tag.get_text().strip()
                    section_title = re.sub(r'\(\d\)', '', section_title)  # drop footnote marks like "(1)"
                    section_title = re.sub(r'^\d+\.\s*', '', section_title)  # drop numbering like "2. "
                    titles.append(section_title)
                else:
                    # No heading inside this block: it continues the last titled section.
                    section_title = titles[-1]

                paragraphs = struct_page.setdefault(section_title, [])
                for child in wysiwyg_tag.find_all(['p', 'li'] + self._HEADING_TAGS):
                    # Identity, not ==: bs4 tags compare equal when their markup is identical.
                    if child is title_tag:
                        continue
                    if 'is-style-mentions' in child.get('class', []):
                        continue
                    text = re.sub(r'\(\d\)', '', child.get_text()).strip()
                    if text:
                        paragraphs.append(text)
                if not paragraphs:
                    del struct_page[section_title]
        return struct_page

    @staticmethod
    def _key_informations(section):
        """Return one text per key item of a 'key-informations' section: button labels then tab-panel texts."""
        key_items = []
        for key_item in section.find_all('div', class_='container-block'):
            key_item_text = ''.join(
                button.get_text().strip() for button in key_item.find_all('div', class_='button')
            )
            for description in key_item.find_all('div', class_="tab-panel"):
                key_item_text += ' ' + description.get_text().strip()
            if key_item_text:
                key_items.append(key_item_text)
        return key_items

    def embed_page(self, html_filename, url, struct_page, tags, max_chunk_size=500):
        """Split a parsed page into passages, embed them and add them to the collection.

        Args:
            html_filename (str): Source file name; document ids are
                ``<html_filename>-1``, ``<html_filename>-2``, ... in passage order.
            url (str): Page URL, stored in every passage's metadata.
            struct_page (dict): ``'title'`` maps to the page title; every other
                key is a section title mapping to its list of paragraphs.
            tags (list of str): URL path segments after the site root, as built
                by `embed_folder`; used to derive the category.
            max_chunk_size (int): Maximum passage length in tokens. Sections
                longer than this are split between paragraphs, and paragraphs
                longer than this are split between words. A single word that
                exceeds the limit on its own is kept as an oversized passage.

        Returns:
            None
        """
        title = struct_page['title']
        documents = []
        for subtitle, paragraphs in struct_page.items():
            if subtitle == 'title' or not paragraphs:
                continue
            doc_str = self.passage_str(paragraphs, title, subtitle)
            if self.token_length(doc_str) <= max_chunk_size:
                documents.append(doc_str)
            else:
                documents.extend(self._chunk_paragraphs(paragraphs, title, subtitle, max_chunk_size))
        if not documents:
            return

        embeddings = self.model.encode(documents, normalize_embeddings=True).tolist()
        metadata = {'category': self._category(tags), 'url': url}
        # All the passages of a same webpage share the same metadata (URL and category).
        metadatas = [dict(metadata) for _ in documents]
        ids = [f"{html_filename}-{i}" for i in range(1, len(documents) + 1)]
        self.collection.add(embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids)

    @staticmethod
    def _category(tags):
        """Return the page category from its URL tags.

        For "https://www.caisse-epargne.fr/rhone-alpes/professionnels/financer-projets-optimiser-tresorerie/"
        the tags are ['rhone-alpes', 'professionnels', 'financer-projets-optimiser-tresorerie']
        and the category is 'professionnels'. With fewer than two tags the category is ''.
        """
        if len(tags) < 2:
            return ''
        return tags[1] if tags[0] == 'rhone-alpes' else tags[0]

    def _fits(self, paragraphs, title, subtitle, max_chunk_size):
        """Return True if the passage built from ``paragraphs`` has at most ``max_chunk_size`` tokens."""
        return self.token_length(self.passage_str(paragraphs, title, subtitle)) <= max_chunk_size

    def _chunk_paragraphs(self, paragraphs, title, subtitle, max_chunk_size):
        """Greedily pack consecutive paragraphs into passages of at most ``max_chunk_size`` tokens.

        A paragraph too long to fit in a passage on its own is split between
        words by `_chunk_words`; its pieces are emitted in place, so the
        original paragraph order is preserved.
        """
        passages = []
        current = []
        for paragraph in paragraphs:
            if self._fits(current + [paragraph], title, subtitle, max_chunk_size):
                current.append(paragraph)
                continue
            # `paragraph` does not fit after `current`: close the current passage.
            if current:
                passages.append(self.passage_str(current, title, subtitle))
                current = []
                if self._fits([paragraph], title, subtitle, max_chunk_size):
                    current = [paragraph]
                    continue
            # `paragraph` is too long even on its own.
            passages.extend(self._chunk_words(paragraph, title, subtitle, max_chunk_size))
        if current:
            passages.append(self.passage_str(current, title, subtitle))
        return passages

    def _chunk_words(self, paragraph, title, subtitle, max_chunk_size):
        """Split one over-long paragraph between words into passages of at most ``max_chunk_size`` tokens.

        A word that exceeds the limit on its own becomes its own (oversized)
        passage rather than producing an empty one.
        """
        passages = []
        words = []
        for word in paragraph.split():
            if words and not self._fits([' '.join(words + [word])], title, subtitle, max_chunk_size):
                passages.append(self.passage_str([' '.join(words)], title, subtitle))
                words = [word]
            else:
                words.append(word)
        if words:
            passages.append(self.passage_str([' '.join(words)], title, subtitle))
        return passages