195 lines
8.1 KiB
Python
195 lines
8.1 KiB
Python
from sentence_transformers import SentenceTransformer
|
|
import chromadb
|
|
from llama_cpp import Llama
|
|
import copy
|
|
import logging
|
|
import re
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
class RAG:
|
|
def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path):
|
|
self.chat_history = []
|
|
self.rag_system_prompt = """
|
|
Vous êtes un assistant IA qui répond à la question posée par l'utilisateur en utilisant un contexte répertorié ci-dessous dans la rubrique Contexte.
|
|
Le contexte est formé de passages exraits du site web commercial de la Caisse d'Epargne Rhône-Alpes, une banque française régionale.
|
|
Votre réponse ne doit pas mentionner des informations déjà présentes dans l'historique de la conversation qui est répertorié ci-dessous dans la rubrique Historique.
|
|
Vous fournissez avec soin des réponses précises, factuelles, réfléchies et nuancées, et vous êtes doué pour le raisonnement.
|
|
Toutes les informations factuelles que vous utilisez pour répondre proviennent exclusivement du contexte.
|
|
Si vous ne pouvez pas répondre à la question sur la base des éléments du contexte, dites simplement que vous ne savez pas, n'essayez pas d'inventer une réponse.
|
|
Vos réponses doivent être brèves.
|
|
Vos utilisateurs savent que vos réponses sont brèves et qu'elles ne mentionnent que les éléments du contexte, il n'est pas nécessaire de le leur rappeler.
|
|
Vous gardez le rôle d'assistant et vous ne générez jamais le texte '<|user|>'.
|
|
Vous rédigez vos réponses en français au format markdown sous forme d'une liste composée de 1 à 7 éléments au maximum.
|
|
Voici le format que doit prendre votre réponse :
|
|
```
|
|
Voici des éléments de réponse :
|
|
1. Elément de réponse. (Passage 1)
|
|
2. Elément de réponse. (Passage 1)
|
|
3. Elément de réponse. (Passage 2)
|
|
4. ...
|
|
```
|
|
|
|
----------------------------------------
|
|
Historique :
|
|
{history}
|
|
----------------------------------------
|
|
Contexte :
|
|
{context}
|
|
"""
|
|
self.query_reformulate_system_prompt = """
|
|
Vous êtes un interprète conversationnel pour une conversation entre un utilisateur et \
|
|
un assistant IA spécialiste des produits et services de la Caisse d'Epargne Rhône-Alpes, \
|
|
une banque régionale française.
|
|
L'utilisateur vous posera une question sans contexte. \
|
|
Vous devez reformuler la question pour prendre en compte le contexte de la conversation.
|
|
Vous devez supposer que la question est liée aux produits et services de la Caisse d'Epargne Rhône-Alpes.
|
|
Vous devez également consulter l'historique de la conversation ci-dessous lorsque vous reformulez la question. \
|
|
Par exemple, vous remplacerez les pronoms par les noms les plus probables dans l'historique de la conversation.
|
|
Lorsque vous reformulez la question, accordez plus d'importance à la dernière question et \
|
|
à la dernière réponse dans l'historique des conversations.
|
|
L'historique des conversations est présenté dans l'ordre chronologique inverse, \
|
|
de sorte que l'échange le plus récent se trouve en haut de la page.
|
|
Répondez en seulement une phrase avec la question reformulée.
|
|
|
|
Historique de la conversation :
|
|
|
|
"""
|
|
|
|
self.embed_model = SentenceTransformer(embed_model_name)
|
|
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
|
|
self.collection = self.chroma_client.get_collection(name=collection_name)
|
|
self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
|
|
|
|
def answer(self, prompt, stream=False):
|
|
response = self.llm(prompt = prompt,
|
|
temperature = 0.1,
|
|
mirostat_mode = 2,
|
|
stream = stream,
|
|
max_tokens = -1,
|
|
stop = ['</s>', ' 8.', '\n\n', '<|user|>'])
|
|
if stream:
|
|
return response
|
|
else: return response["choices"][0]["text"]
|
|
|
|
def query_collection(self, query, n_results=3):
|
|
query = 'query: ' + query
|
|
query_embedding = self.embed_model.encode(query, normalize_embeddings=True)
|
|
query_embedding = query_embedding.tolist()
|
|
results = self.collection.query(
|
|
query_embeddings=[query_embedding],
|
|
n_results=n_results,
|
|
)
|
|
return results
|
|
|
|
def format_passages(self, query_results):
|
|
result = []
|
|
for i in range(len(query_results["documents"][0])):
|
|
passage = query_results["documents"][0][i]
|
|
url = query_results["metadatas"][0][i]["url"]
|
|
category = query_results["metadatas"][0][i]["category"]
|
|
lines = passage.split('\n')
|
|
if lines[0].startswith('passage: '):
|
|
lines[0] = lines[0].replace('passage: ', '')
|
|
lines.insert(0, "###")
|
|
lines.insert(1, f"Passage {i+1}")
|
|
lines.insert(2, "Titre :")
|
|
lines.insert(4, "")
|
|
lines.insert(5, "Catégorie :")
|
|
lines.insert(6, category)
|
|
lines.insert(7, "")
|
|
lines.insert(8, "URL :")
|
|
lines.insert(9, url)
|
|
lines.insert(10, "")
|
|
lines.insert(11, "Contenu : ")
|
|
lines += ['']
|
|
result += lines
|
|
result = '\n'.join(result)
|
|
return result
|
|
|
|
def format_rag_prompt(self, query, context="", history=""):
|
|
|
|
user_query = f"Question de l'utilisateur : \n{query}\n\n"
|
|
assistant_answer = f"Réponse de l'assistant : \n 1. "
|
|
self.chat_history.append({'user': user_query, 'assistant': assistant_answer})
|
|
|
|
system_prompt = self.rag_system_prompt.format(history=history, context=context)
|
|
|
|
prompt = ""
|
|
prompt = f"<|system|>\n{system_prompt.strip()}</s>\n"
|
|
prompt += f"<|user|>\n{query}</s>\n"
|
|
prompt += f"<|assistant|>\n Voici des éléments de réponse : \n 1. "
|
|
|
|
return prompt
|
|
|
|
def remove_references(self, text):
|
|
motif = r"\(Passage \d+\)"
|
|
res = re.sub(motif, '', text)
|
|
return res
|
|
|
|
def answer_rag_prompt(self, query, query_results, stream=False):
|
|
|
|
query_context = self.format_passages(query_results)
|
|
|
|
history = ""
|
|
for i in reversed(range(len(self.chat_history))):
|
|
history += self.chat_history[i]["user"]
|
|
history += self.remove_references(self.chat_history[i]["assistant"])
|
|
history += "\n\n"
|
|
|
|
prompt = self.format_rag_prompt(query, query_context, history)
|
|
|
|
logging.info(prompt)
|
|
|
|
ans = self.answer(prompt, stream)
|
|
if stream:
|
|
yield ' 1. '
|
|
for item in ans:
|
|
item = item["choices"][0]["text"]
|
|
self.chat_history[-1]['assistant'] += item
|
|
yield item
|
|
else:
|
|
self.chat_history[-1]['assistant'] += ans
|
|
ans = '1. ' + ans
|
|
return ans
|
|
|
|
def format_prompt_reformulate_query(self, query):
|
|
system_prompt = self.query_reformulate_system_prompt
|
|
|
|
for i in reversed(range(len(self.chat_history))):
|
|
system_prompt += self.chat_history[i]["user"]
|
|
system_prompt += self.chat_history[i]["assistant"]
|
|
|
|
prompt = ""
|
|
prompt = f"<|system|>\n{system_prompt.strip()}</s>\n"
|
|
prompt += f"<|user|>\nPeux-tu reformuler la question suivante : \n \"{query}\"</s>\n"
|
|
prompt += f"<|assistant|> Question reformulée : \n\""
|
|
|
|
return prompt
|
|
|
|
def reformulate_query(self, query):
|
|
prompt = self.format_prompt_reformulate_query(query)
|
|
logging.info(prompt)
|
|
ans = self.answer(prompt)
|
|
|
|
last_quote_index = ans.rfind('"')
|
|
if last_quote_index != -1:
|
|
ans = ans[:last_quote_index]
|
|
|
|
if len(ans) > 10:
|
|
logging.info(f"Requête reformulée : \"{ans}\"")
|
|
return ans
|
|
else:
|
|
logging.info(f"La requête n'a pas pu être reformulée.")
|
|
return query
|
|
|
|
def chat(self, query):
|
|
if len(self.chat_history) > 0:
|
|
query = self.reformulate_query(query)
|
|
query_results = self.query_collection(query)
|
|
ans = self.answer_rag_prompt(query, query_results, stream=True)
|
|
return ans
|
|
|
|
def reset_history(self):
|
|
self.chat_history = []
|