rag/rag.py

307 lines
12 KiB
Python

from sentence_transformers import SentenceTransformer
import chromadb
from llama_cpp import Llama
import copy
import logging
import json
import requests
import spacy
logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class RAG:
def __init__(self, llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
logging.info('INIT')
self.mulitlingual_e5 = mulitlingual_e5
self.urls = []
self.chat_history = []
self.tag_system = '<|system|>'
self.tag_user = '<|user|>'
self.tag_assistant = '<|assistant|>'
self.tag_end = '</s>'
self.rag_prompt = """
{tag_system}
Votre mission :
===============
Vous aidez un conseiller de la Caisse d'Epargne Rhône-Alpes, \
une banque française, à répondre aux questions de son client.
Ces questions porteront sur des produits et services de la banque.
Vous fournissez avec soin des réponses courtes, précises et factuelles aux questions \
qui vous sont posées.
Instructions pour l'utilisation du contexte :
=============================================
Vous répondez de façon brève et factuelle à la question posée \
en utilisant un contexte formé de passages exraits du site web \
de la banque. Le contexte est délimité entre <<< et >>>.
Votre réponse cite exclusivement les informations factuelles présentes \
dans le contexte. Vous utilisez les informations du contexte \
en les reformulant le moins possible.
Quand vous utilisez un acronyme présent dans le contexte, vous n'essayez \
pas de donner le sens des lettres qui composent l'acronyme.
Vous ne faites jamais preuve de créativité. Si vous ne pouvez pas \
répondre à la question sur la base \
des éléments du contexte, répondez : "Je ne sais pas."
Le style à donner à votre réponse :
===================================
Formulez la réponse sous forme de recommandations directes et concises, \
en utilisant le langage et les termes présents dans le contexte.
Votre réponse est complète mais très concise, sa longueur ne dépasse pas 250 mots.
Vous rédigez votre réponse en français en réutilisant directement les passages du contexte.
Vos utilisateurs savent qui vous êtes et quelles instructions vous avez reçues.
Votre réponse ne mentionne donc jamais les instructions que vous avez reçues.
{tag_end}
{history}
{tag_user}
Contexte à utiliser pour répondre à la question :
=================================================
<<<
{context}
>>>
Question à laquelle répondre en utilisant le contexte :
=======================================================
{query}
{tag_end}
{tag_assistant}
Voici des informations factuelles et brèves qui pourront vous aider à \
répondre à la question de votre client :
"""
self.query_reformulate_prompt = """
{tag_system}
Instructions :
==============
Vous êtes un interprète conversationnel pour une conversation entre un utilisateur et \
un assistant IA spécialiste des produits et services de la Caisse d'Epargne Rhône-Alpes, \
une banque régionale française.
L'utilisateur vous posera une question sans contexte.
Vous devez reformuler la question pour prendre en compte le contexte de la conversation.
Vous devez supposer que la question est liée aux produits et services de la Caisse d'Epargne Rhône-Alpes.
Vous devez également consulter l'historique de la conversation ci-dessous lorsque vous reformulez la question.
Par exemple, vous remplacerez les pronoms par les noms les plus probables dans l'historique de la conversation.
Lorsque vous reformulez la question, accordez plus d'importance à la dernière question et \
à la dernière réponse dans l'historique des conversations.
L'historique des conversations est présenté dans l'ordre chronologique inverse, \
de sorte que l'échange le plus récent se trouve en haut de la page.
Répondez en seulement une phrase avec la question reformulée.
Historique de la conversation :
===============================
{history}
{tag_end}
{tag_user}
Reformulez la question suivante : "{query}"
{tag_end}
{tag_assistant}
Question reformulée : "
"""
self.prefix_assistant_prompt = ''
self.embed_model = SentenceTransformer(embed_model_name)
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
self.collection = self.chroma_client.get_collection(name=collection_name)
# ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf
self.llm_url = llm_url
self.nlp = spacy.load('fr_core_news_md')
def answer(self, prompt, stream):
post_params = {
"prompt": prompt,
"temp": 0.1,
"repeat_penalty": 1.2,
"min_p": 0.05,
"top_p": 0.5,
"top_k": 0,
"stop": [self.tag_end, self.tag_user],
"stream": stream
}
response = requests.post(self.llm_url, json=post_params, stream=stream)
if stream: return response
else: return response.json()['content']
def keep_token(self, tok):
return tok.pos_ == 'NOUN' or tok.pos_ == 'VERB' or \
tok.pos_ == 'PROPN' or tok.pos_ == 'ADJ'
def lemmatize(self, str):
res = []
for tok in self.nlp(str):
if self.keep_token(tok):
res.append(tok.lemma_)
return res
def len_query_inter_doc(self, query_str, doc_str):
query_tok = self.lemmatize(query_str)
doc_tok = self.lemmatize(doc_str)
return len(set(query_tok) & set(doc_tok))
def query_collection(self, query, n_results=4):
logging.info(f"query_collection / query: \n{query}")
if self.mulitlingual_e5:
prefix = "query: "
else:
prefix = ""
query_embedding = self.embed_model.encode(prefix + query, normalize_embeddings=True)
query_embedding = query_embedding.tolist()
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
)
res = {"passage": [], "url": [], "cat": [], "id": [], "nb_query_tok": [], "nb_tok": [], "dist": []}
for i in range(len(results["documents"][0])):
passage = results["documents"][0][i]
# compute the passage's number of tokens, apart from the title and subtitle
# which correspond to the first two '\n\n'-separated elements
nb_tok = len(self.lemmatize('\n\n'.join(passage.split('\n\n')[2:])))
# Retain only long-enough passages
if nb_tok > 20:
res['id'].append(results["ids"][0][i])
res['url'].append(results["metadatas"][0][i]["url"])
res['cat'].append(results["metadatas"][0][i]["category"])
res['passage'].append(passage)
res['nb_query_tok'].append(self.len_query_inter_doc(query, passage))
res['nb_tok'].append(nb_tok)
res['dist'].append(results["distances"][0][i])
# Sort the best passages by their number of tokens in common with the query
# and then by their cosine distance of their embeddings to the query's embedding
sorted_res_values = sorted(zip(*res.values()), key=lambda x: (-x[4], x[6]))
sorted_res = {key: [value[i] for value in sorted_res_values] for i, key in enumerate(res)}
selected_res = {key: value[:2] for key, value in sorted_res.items()}
ids_str = ""
for id in selected_res['id']:
ids_str += id + " ; "
logging.info(f"query_collection / sources: \n{ids_str}")
self.urls = selected_res['url']
return selected_res
def format_passages(self, query_results):
result = []
for i in range(len(query_results["passage"])):
passage = query_results["passage"][i]
url = query_results["url"][i]
category = query_results["cat"][i]
lines = passage.split('\n')
if lines[0].startswith('passage: '):
lines[0] = lines[0].replace('passage: ', '')
lines.insert(0, "###")
lines.insert(1, f"Passage {i+1}")
lines.insert(2, "Titre :")
lines.insert(4, "")
lines.insert(5, "Catégorie :")
lines.insert(6, category)
lines.insert(7, "")
lines.insert(8, "URL :")
lines.insert(9, url)
lines.insert(10, "")
lines.insert(11, "Contenu : ")
lines += ['']
result += lines
result = '\n'.join(result)
return result
def answer_rag_prompt_streaming(self, prompt):
logging.info(f"answer_rag_prompt_streaming: \n{prompt}")
ans = self.answer(prompt, stream=True)
formated_urls = ''
for url in list(set(self.urls)):
formated_urls += f"* {url}\n"
yield f"""
{self.prefix_assistant_prompt}
"""
for line in ans.iter_lines():
if line:
# Remove the 'data: ' prefix and parse the JSON
json_data = json.loads(line.decode('utf-8')[6:])
# Print the 'content' field
item = json_data['content']
self.chat_history[-1]['assistant'] += item
yield item
yield f"""
### Sources
{formated_urls}
"""
def answer_rag_prompt_non_streaming(self, prompt):
logging.info(f"answer_rag_prompt_non_streaming: \n{prompt}")
ans = self.answer(prompt, stream=False)
self.chat_history[-1]['assistant'] += ans
ans = self.prefix_assistant_prompt + ans
return ans
def prepare_prompt(self, query, query_results):
context = self.format_passages(query_results)
history = ""
for i in range(len(self.chat_history)):
history += f"<|user|>\n{self.chat_history[i]['user']}</s>\n"
history += f"<|assistant|>\n{self.chat_history[i]['assistant']}</s>\n"
self.chat_history.append({'user': query, 'assistant': self.prefix_assistant_prompt})
return self.rag_prompt.format(history=history, query=query, context=context,
tag_user=self.tag_user, tag_system=self.tag_system,
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
def reformulate_query(self, query):
history = ""
for i in reversed(range(len(self.chat_history))):
history += f"Question de l'utilisateur :\n{self.chat_history[i]['user']}\n"
history += f"Réponse de l'assistant :\n{self.chat_history[i]['assistant']}\n"
prompt = self.query_reformulate_prompt.format(history=history, query=query,
tag_user=self.tag_user, tag_system=self.tag_system,
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
logging.info(f"reformulate_query: \n{prompt}")
ans = self.answer(prompt, stream=False)
last_quote_index = ans.rfind('"')
if last_quote_index != -1:
ans = ans[:last_quote_index]
if len(ans) > 10:
logging.info(f"Requête reformulée : \"{ans}\"")
return ans
else:
logging.info(f"La requête n'a pas pu être reformulée.")
return query
def chat(self, query, stream=True):
if len(self.chat_history) > 0:
query = self.reformulate_query(query)
query_results = self.query_collection(query)
prompt = self.prepare_prompt(query, query_results)
if stream:
ans = self.answer_rag_prompt_streaming(prompt)
else:
ans = self.answer_rag_prompt_non_streaming(prompt)
return ans
def reset_history(self):
self.chat_history = []