rag/rag.py

195 lines
8.0 KiB
Python
Raw Normal View History

2024-01-03 15:24:58 +00:00
from sentence_transformers import SentenceTransformer
import chromadb
from llama_cpp import Llama
import copy
import logging
import re
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class RAG:
def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path):
self.chat_history = []
self.rag_system_prompt = """
Vous êtes un assistant IA qui répond à la question posée par l'utilisateur en utilisant un contexte répertorié ci-dessous dans la rubrique Contexte.
Le contexte est formé de passages exraits du site web commercial de la Caisse d'Epargne Rhône-Alpes, une banque française régionale.
Votre réponse ne doit pas mentionner des informations déjà présentes dans l'historique de la conversation qui est répertorié ci-dessous dans la rubrique Historique.
Vous fournissez avec soin des réponses précises, factuelles, réfléchies et nuancées, et vous êtes doué pour le raisonnement.
Toutes les informations factuelles que vous utilisez pour répondre proviennent exclusivement du contexte.
Si vous ne pouvez pas répondre à la question sur la base des éléments du contexte, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse.
Vos réponses doivent être brèves.
Vos utilisateurs savent que vos réponses sont brèves et qu'elles ne mentionnent que les éléments du contexte, il n'est pas nécessaire de le leur rappeler.
Vous gardez le rôle d'assistant et vous ne générez jamais le texte '<|user|>'.
Vous rédigez vos réponses en français au format markdown sous forme d'une liste composée de 1 à 7 éléments au maximum.
Voici le format que doit prendre votre réponse :
```
Voici des éléments de réponse :
2024-01-03 20:36:33 +00:00
1. Elément de réponse.
2. Elément de réponse.
3. Elément de réponse.
2024-01-03 15:24:58 +00:00
4. ...
```
----------------------------------------
Historique :
{history}
----------------------------------------
Contexte :
{context}
"""
self.query_reformulate_system_prompt = """
Vous êtes un interprète conversationnel pour une conversation entre un utilisateur et \
un assistant IA spécialiste des produits et services de la Caisse d'Epargne Rhône-Alpes, \
une banque régionale française.
L'utilisateur vous posera une question sans contexte. \
Vous devez reformuler la question pour prendre en compte le contexte de la conversation.
Vous devez supposer que la question est liée aux produits et services de la Caisse d'Epargne Rhône-Alpes.
Vous devez également consulter l'historique de la conversation ci-dessous lorsque vous reformulez la question. \
Par exemple, vous remplacerez les pronoms par les noms les plus probables dans l'historique de la conversation.
Lorsque vous reformulez la question, accordez plus d'importance à la dernière question et \
à la dernière réponse dans l'historique des conversations.
L'historique des conversations est présenté dans l'ordre chronologique inverse, \
de sorte que l'échange le plus récent se trouve en haut de la page.
Répondez en seulement une phrase avec la question reformulée.
Historique de la conversation :
"""
self.embed_model = SentenceTransformer(embed_model_name)
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
self.collection = self.chroma_client.get_collection(name=collection_name)
self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
def answer(self, prompt, stream=False):
response = self.llm(prompt = prompt,
temperature = 0.1,
mirostat_mode = 2,
stream = stream,
max_tokens = -1,
stop = ['</s>', ' 8.', '\n\n', '<|user|>'])
if stream:
return response
else: return response["choices"][0]["text"]
def query_collection(self, query, n_results=3):
query = 'query: ' + query
query_embedding = self.embed_model.encode(query, normalize_embeddings=True)
query_embedding = query_embedding.tolist()
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
)
return results
def format_passages(self, query_results):
result = []
for i in range(len(query_results["documents"][0])):
passage = query_results["documents"][0][i]
url = query_results["metadatas"][0][i]["url"]
category = query_results["metadatas"][0][i]["category"]
lines = passage.split('\n')
if lines[0].startswith('passage: '):
lines[0] = lines[0].replace('passage: ', '')
lines.insert(0, "###")
lines.insert(1, f"Passage {i+1}")
lines.insert(2, "Titre :")
lines.insert(4, "")
lines.insert(5, "Catégorie :")
lines.insert(6, category)
lines.insert(7, "")
lines.insert(8, "URL :")
lines.insert(9, url)
lines.insert(10, "")
lines.insert(11, "Contenu : ")
lines += ['']
result += lines
result = '\n'.join(result)
return result
def format_rag_prompt(self, query, context="", history=""):
user_query = f"Question de l'utilisateur : \n{query}\n\n"
assistant_answer = f"Réponse de l'assistant : \n 1. "
self.chat_history.append({'user': user_query, 'assistant': assistant_answer})
system_prompt = self.rag_system_prompt.format(history=history, context=context)
prompt = ""
prompt = f"<|system|>\n{system_prompt.strip()}</s>\n"
prompt += f"<|user|>\n{query}</s>\n"
prompt += f"<|assistant|>\n Voici des éléments de réponse : \n 1. "
return prompt
def remove_references(self, text):
motif = r"\(Passage \d+\)"
res = re.sub(motif, '', text)
return res
def answer_rag_prompt(self, query, query_results, stream=False):
query_context = self.format_passages(query_results)
history = ""
for i in reversed(range(len(self.chat_history))):
history += self.chat_history[i]["user"]
history += self.remove_references(self.chat_history[i]["assistant"])
history += "\n\n"
prompt = self.format_rag_prompt(query, query_context, history)
logging.info(prompt)
ans = self.answer(prompt, stream)
if stream:
yield ' 1. '
for item in ans:
item = item["choices"][0]["text"]
self.chat_history[-1]['assistant'] += item
yield item
else:
self.chat_history[-1]['assistant'] += ans
ans = '1. ' + ans
return ans
def format_prompt_reformulate_query(self, query):
system_prompt = self.query_reformulate_system_prompt
for i in reversed(range(len(self.chat_history))):
system_prompt += self.chat_history[i]["user"]
system_prompt += self.chat_history[i]["assistant"]
prompt = ""
prompt = f"<|system|>\n{system_prompt.strip()}</s>\n"
prompt += f"<|user|>\nPeux-tu reformuler la question suivante : \n \"{query}\"</s>\n"
prompt += f"<|assistant|> Question reformulée : \n\""
return prompt
def reformulate_query(self, query):
prompt = self.format_prompt_reformulate_query(query)
logging.info(prompt)
ans = self.answer(prompt)
last_quote_index = ans.rfind('"')
if last_quote_index != -1:
ans = ans[:last_quote_index]
if len(ans) > 10:
logging.info(f"Requête reformulée : \"{ans}\"")
return ans
else:
logging.info(f"La requête n'a pas pu être reformulée.")
return query
def chat(self, query):
if len(self.chat_history) > 0:
query = self.reformulate_query(query)
query_results = self.query_collection(query)
ans = self.answer_rag_prompt(query, query_results, stream=True)
return ans
def reset_history(self):
self.chat_history = []