252 lines
9.6 KiB
Python
252 lines
9.6 KiB
Python
from sentence_transformers import SentenceTransformer
|
|
import chromadb
|
|
from llama_cpp import Llama
|
|
import copy
|
|
import logging
|
|
|
|
logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
class RAG:
|
|
def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
|
|
logging.info('INIT')
|
|
self.mulitlingual_e5 = mulitlingual_e5
|
|
self.urls = []
|
|
self.chat_history = []
|
|
self.tag_system = '<|system|>'
|
|
self.tag_user = '<|user|>'
|
|
self.tag_assistant = '<|assistant|>'
|
|
self.tag_end = '</s>'
|
|
self.rag_prompt = """
|
|
{tag_system}
|
|
Votre mission :
|
|
===============
|
|
|
|
Vous êtes un assistant IA qui répond à des questions sur des produits et \
|
|
services de la Caisse d'Epargne Rhône-Alpes, une banque régionale française.
|
|
Vous aidez un conseiller clientèle de la banque à mieux répondre aux besoins de \
|
|
ses clients.
|
|
Vous fournissez avec soin des réponses précises et factuelles aux questions du \
|
|
conseiller.
|
|
|
|
Instructions pour l'utilisation du contexte :
|
|
=============================================
|
|
|
|
Vous répondez de façon brève et factuelle à la question posée par le conseiller \
|
|
en utilisant un contexte formé de passages exraits du site web commercial \
|
|
de la banque. Le contexte est délimité entre <<< et >>>.
|
|
Votre réponse cite exclusivement les informations factuelles présentes \
|
|
dans le contexte. Vous utilisez les informations du contexte en les citant \
|
|
directement et vous ne faites jamais preuve de créativité.
|
|
Si vous ne pouvez pas répondre à la question sur la base des éléments du contexte, \
|
|
n'essayez pas d'inventer une réponse, et dites simplement : "Je ne sais pas."
|
|
|
|
Le style à donner à votre réponse :
|
|
===================================
|
|
|
|
Formulez la réponse sous forme de recommandations directes et concises, \
|
|
en utilisant le langage et les termes présents dans le contexte.
|
|
Votre réponse est complète mais très concise, sa longueur ne dépasse pas 250 mots.
|
|
Vous ne répétez jamais deux fois la même information.
|
|
Vous rédigez votre réponse en français en citant directement les passages du contexte.
|
|
Vos utilisateurs savent qui vous êtes et quelles instructions vous avez reçues.
|
|
Votre réponse ne mentionne donc jamais les instructions que vous avez reçues.
|
|
{tag_end}
|
|
|
|
{history}
|
|
|
|
{tag_user}
|
|
|
|
Contexte à utiliser pour répondre à la question :
|
|
=================================================
|
|
<<<
|
|
{context}
|
|
>>>
|
|
|
|
Question à laquelle répondre en utilisant le contexte :
|
|
=======================================================
|
|
{query}
|
|
|
|
{tag_end}
|
|
{tag_assistant}
|
|
Voici des informations factuelles et brèves qui répondent à la question :
|
|
|
|
"""
|
|
self.query_reformulate_prompt = """
|
|
{tag_system}
|
|
Instructions :
|
|
==============
|
|
|
|
Vous êtes un interprète conversationnel pour une conversation entre un utilisateur et \
|
|
un assistant IA spécialiste des produits et services de la Caisse d'Epargne Rhône-Alpes, \
|
|
une banque régionale française.
|
|
L'utilisateur vous posera une question sans contexte.
|
|
Vous devez reformuler la question pour prendre en compte le contexte de la conversation.
|
|
Vous devez supposer que la question est liée aux produits et services de la Caisse d'Epargne Rhône-Alpes.
|
|
Vous devez également consulter l'historique de la conversation ci-dessous lorsque vous reformulez la question.
|
|
Par exemple, vous remplacerez les pronoms par les noms les plus probables dans l'historique de la conversation.
|
|
Lorsque vous reformulez la question, accordez plus d'importance à la dernière question et \
|
|
à la dernière réponse dans l'historique des conversations.
|
|
L'historique des conversations est présenté dans l'ordre chronologique inverse, \
|
|
de sorte que l'échange le plus récent se trouve en haut de la page.
|
|
Répondez en seulement une phrase avec la question reformulée.
|
|
|
|
Historique de la conversation :
|
|
===============================
|
|
|
|
{history}
|
|
{tag_end}
|
|
{tag_user}
|
|
Reformulez la question suivante : "{query}"
|
|
{tag_end}
|
|
{tag_assistant}
|
|
Question reformulée : "
|
|
"""
|
|
|
|
self.prefix_assistant_prompt = ''
|
|
self.embed_model = SentenceTransformer(embed_model_name)
|
|
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
|
|
self.collection = self.chroma_client.get_collection(name=collection_name)
|
|
self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
|
|
|
|
def answer(self, prompt, stream):
|
|
response = self.llm(prompt = prompt,
|
|
temperature = 0.1,
|
|
mirostat_mode = 2,
|
|
stream = stream,
|
|
max_tokens = -1,
|
|
stop = [self.tag_end, self.tag_user])
|
|
if stream:
|
|
return response
|
|
else: return response["choices"][0]["text"]
|
|
|
|
def query_collection(self, query, n_results=3):
|
|
logging.info(f"query_collection / query: \n{query}")
|
|
if self.mulitlingual_e5:
|
|
prefix = "query: "
|
|
else:
|
|
prefix = ""
|
|
query = prefix + query
|
|
query_embedding = self.embed_model.encode(query, normalize_embeddings=True)
|
|
query_embedding = query_embedding.tolist()
|
|
results = self.collection.query(
|
|
query_embeddings=[query_embedding],
|
|
n_results=n_results,
|
|
)
|
|
|
|
ids_sources = ""
|
|
for i in range(len(results["documents"][0])):
|
|
id = results["ids"][0][i]
|
|
ids_sources += id + " ; "
|
|
logging.info(f"query_collection / sources: \n{ids_sources}")
|
|
|
|
self.urls = []
|
|
for i in range(len(results["documents"][0])):
|
|
self.urls.append(results["metadatas"][0][i]["url"])
|
|
|
|
return results
|
|
|
|
def format_passages(self, query_results):
|
|
result = []
|
|
for i in range(len(query_results["documents"][0])):
|
|
passage = query_results["documents"][0][i]
|
|
url = query_results["metadatas"][0][i]["url"]
|
|
category = query_results["metadatas"][0][i]["category"]
|
|
lines = passage.split('\n')
|
|
if lines[0].startswith('passage: '):
|
|
lines[0] = lines[0].replace('passage: ', '')
|
|
lines.insert(0, "###")
|
|
lines.insert(1, f"Passage {i+1}")
|
|
lines.insert(2, "Titre :")
|
|
lines.insert(4, "")
|
|
lines.insert(5, "Catégorie :")
|
|
lines.insert(6, category)
|
|
lines.insert(7, "")
|
|
lines.insert(8, "URL :")
|
|
lines.insert(9, url)
|
|
lines.insert(10, "")
|
|
lines.insert(11, "Contenu : ")
|
|
lines += ['']
|
|
result += lines
|
|
result = '\n'.join(result)
|
|
return result
|
|
|
|
def answer_rag_prompt_streaming(self, prompt):
|
|
logging.info(f"answer_rag_prompt_streaming: \n{prompt}")
|
|
ans = self.answer(prompt, stream=True)
|
|
formated_urls = ''
|
|
for url in list(set(self.urls)):
|
|
formated_urls += f"* {url}\n"
|
|
|
|
yield f"""
|
|
### Réponse
|
|
{self.prefix_assistant_prompt}
|
|
"""
|
|
for item in ans:
|
|
item = item["choices"][0]["text"]
|
|
self.chat_history[-1]['assistant'] += item
|
|
yield item
|
|
yield f"""
|
|
### Sources
|
|
{formated_urls}
|
|
"""
|
|
|
|
def answer_rag_prompt_non_streaming(self, prompt):
|
|
logging.info(f"answer_rag_prompt_non_streaming: \n{prompt}")
|
|
ans = self.answer(prompt, stream=False)
|
|
|
|
self.chat_history[-1]['assistant'] += ans
|
|
ans = self.prefix_assistant_prompt + ans
|
|
return ans
|
|
|
|
def prepare_prompt(self, query, query_results):
|
|
context = self.format_passages(query_results)
|
|
|
|
history = ""
|
|
for i in range(len(self.chat_history)):
|
|
history += f"<|user|>\n{self.chat_history[i]['user']}</s>\n"
|
|
history += f"<|assistant|>\n{self.chat_history[i]['assistant']}</s>\n"
|
|
|
|
self.chat_history.append({'user': query, 'assistant': self.prefix_assistant_prompt})
|
|
|
|
return self.rag_prompt.format(history=history, query=query, context=context,
|
|
tag_user=self.tag_user, tag_system=self.tag_system,
|
|
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
|
|
|
|
def reformulate_query(self, query):
|
|
history = ""
|
|
for i in reversed(range(len(self.chat_history))):
|
|
history += f"Question de l'utilisateur :\n{self.chat_history[i]['user']}\n"
|
|
history += f"Réponse de l'assistant :\n{self.chat_history[i]['assistant']}\n"
|
|
|
|
prompt = self.query_reformulate_prompt.format(history=history, query=query,
|
|
tag_user=self.tag_user, tag_system=self.tag_system,
|
|
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
|
|
logging.info(f"reformulate_query: \n{prompt}")
|
|
ans = self.answer(prompt, stream=False)
|
|
|
|
last_quote_index = ans.rfind('"')
|
|
if last_quote_index != -1:
|
|
ans = ans[:last_quote_index]
|
|
|
|
if len(ans) > 10:
|
|
logging.info(f"Requête reformulée : \"{ans}\"")
|
|
return ans
|
|
else:
|
|
logging.info(f"La requête n'a pas pu être reformulée.")
|
|
return query
|
|
|
|
def chat(self, query, stream=True):
|
|
if len(self.chat_history) > 0:
|
|
query = self.reformulate_query(query)
|
|
query_results = self.query_collection(query)
|
|
prompt = self.prepare_prompt(query, query_results)
|
|
if stream:
|
|
ans = self.answer_rag_prompt_streaming(prompt)
|
|
else:
|
|
ans = self.answer_rag_prompt_non_streaming(prompt)
|
|
return ans
|
|
|
|
def reset_history(self):
|
|
self.chat_history = []
|