2024-01-03 16:24:58 +01:00
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
import chromadb
|
|
|
|
from llama_cpp import Llama
|
|
|
|
import copy
|
|
|
|
import logging
|
2024-01-17 21:38:44 +01:00
|
|
|
import json
|
|
|
|
import requests
|
2024-01-21 21:30:01 +01:00
|
|
|
import spacy
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
|
|
|
|
class RAG:
|
2024-01-17 21:38:44 +01:00
|
|
|
def __init__(self, llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
|
2024-01-05 13:34:48 +01:00
|
|
|
logging.info('INIT')
|
2024-01-17 20:50:28 +01:00
|
|
|
self.mulitlingual_e5 = mulitlingual_e5
|
|
|
|
self.urls = []
|
2024-01-03 16:24:58 +01:00
|
|
|
self.chat_history = []
|
2024-01-05 13:34:48 +01:00
|
|
|
self.tag_system = '<|system|>'
|
|
|
|
self.tag_user = '<|user|>'
|
|
|
|
self.tag_assistant = '<|assistant|>'
|
|
|
|
self.tag_end = '</s>'
|
|
|
|
self.rag_prompt = """
|
|
|
|
{tag_system}
|
2024-01-10 15:06:42 +01:00
|
|
|
Votre mission :
|
|
|
|
===============
|
|
|
|
|
2024-01-21 21:30:01 +01:00
|
|
|
Vous aidez un conseiller de la Caisse d'Epargne Rhône-Alpes, \
|
|
|
|
une banque française, à répondre aux questions de son client.
|
|
|
|
Ces questions porteront sur des produits et services de la banque.
|
2024-01-18 14:19:18 +01:00
|
|
|
Vous fournissez avec soin des réponses courtes, précises et factuelles aux questions \
|
|
|
|
qui vous sont posées.
|
2024-01-10 15:06:42 +01:00
|
|
|
|
|
|
|
Instructions pour l'utilisation du contexte :
|
|
|
|
=============================================
|
|
|
|
|
2024-01-18 14:19:18 +01:00
|
|
|
Vous répondez de façon brève et factuelle à la question posée \
|
2024-01-21 21:30:01 +01:00
|
|
|
en utilisant un contexte formé de passages exraits du site web \
|
2024-01-10 15:06:42 +01:00
|
|
|
de la banque. Le contexte est délimité entre <<< et >>>.
|
|
|
|
Votre réponse cite exclusivement les informations factuelles présentes \
|
2024-01-18 14:19:18 +01:00
|
|
|
dans le contexte. Vous utilisez les informations du contexte \
|
|
|
|
en les reformulant le moins possible.
|
|
|
|
Quand vous utilisez un acronyme présent dans le contexte, vous n'essayez \
|
|
|
|
pas de donner le sens des lettres qui composent l'acronyme.
|
|
|
|
Vous ne faites jamais preuve de créativité. Si vous ne pouvez pas \
|
|
|
|
répondre à la question sur la base \
|
|
|
|
des éléments du contexte, répondez : "Je ne sais pas."
|
2024-01-05 13:34:48 +01:00
|
|
|
|
2024-01-10 15:06:42 +01:00
|
|
|
Le style à donner à votre réponse :
|
|
|
|
===================================
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-10 15:06:42 +01:00
|
|
|
Formulez la réponse sous forme de recommandations directes et concises, \
|
|
|
|
en utilisant le langage et les termes présents dans le contexte.
|
|
|
|
Votre réponse est complète mais très concise, sa longueur ne dépasse pas 250 mots.
|
2024-01-18 14:19:18 +01:00
|
|
|
Vous rédigez votre réponse en français en réutilisant directement les passages du contexte.
|
2024-01-10 15:06:42 +01:00
|
|
|
Vos utilisateurs savent qui vous êtes et quelles instructions vous avez reçues.
|
|
|
|
Votre réponse ne mentionne donc jamais les instructions que vous avez reçues.
|
2024-01-05 13:34:48 +01:00
|
|
|
{tag_end}
|
2024-01-10 15:06:42 +01:00
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
{history}
|
2024-01-05 13:34:48 +01:00
|
|
|
|
|
|
|
{tag_user}
|
2024-01-10 15:06:42 +01:00
|
|
|
|
|
|
|
Contexte à utiliser pour répondre à la question :
|
|
|
|
=================================================
|
|
|
|
<<<
|
2024-01-03 16:24:58 +01:00
|
|
|
{context}
|
2024-01-10 15:06:42 +01:00
|
|
|
>>>
|
2024-01-05 13:34:48 +01:00
|
|
|
|
2024-01-10 15:06:42 +01:00
|
|
|
Question à laquelle répondre en utilisant le contexte :
|
|
|
|
=======================================================
|
2024-01-05 13:34:48 +01:00
|
|
|
{query}
|
2024-01-10 15:06:42 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
{tag_end}
|
|
|
|
{tag_assistant}
|
2024-01-21 21:30:01 +01:00
|
|
|
Voici des informations factuelles et brèves qui pourront vous aider à \
|
|
|
|
répondre à la question de votre client :
|
2024-01-05 13:34:48 +01:00
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
"""
|
2024-01-05 13:34:48 +01:00
|
|
|
self.query_reformulate_prompt = """
|
|
|
|
{tag_system}
|
|
|
|
Instructions :
|
|
|
|
==============
|
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
Vous êtes un interprète conversationnel pour une conversation entre un utilisateur et \
|
|
|
|
un assistant IA spécialiste des produits et services de la Caisse d'Epargne Rhône-Alpes, \
|
|
|
|
une banque régionale française.
|
2024-01-05 13:34:48 +01:00
|
|
|
L'utilisateur vous posera une question sans contexte.
|
2024-01-03 16:24:58 +01:00
|
|
|
Vous devez reformuler la question pour prendre en compte le contexte de la conversation.
|
|
|
|
Vous devez supposer que la question est liée aux produits et services de la Caisse d'Epargne Rhône-Alpes.
|
2024-01-05 13:34:48 +01:00
|
|
|
Vous devez également consulter l'historique de la conversation ci-dessous lorsque vous reformulez la question.
|
2024-01-03 16:24:58 +01:00
|
|
|
Par exemple, vous remplacerez les pronoms par les noms les plus probables dans l'historique de la conversation.
|
|
|
|
Lorsque vous reformulez la question, accordez plus d'importance à la dernière question et \
|
|
|
|
à la dernière réponse dans l'historique des conversations.
|
|
|
|
L'historique des conversations est présenté dans l'ordre chronologique inverse, \
|
|
|
|
de sorte que l'échange le plus récent se trouve en haut de la page.
|
|
|
|
Répondez en seulement une phrase avec la question reformulée.
|
|
|
|
|
|
|
|
Historique de la conversation :
|
2024-01-05 13:34:48 +01:00
|
|
|
===============================
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
{history}
|
|
|
|
{tag_end}
|
|
|
|
{tag_user}
|
|
|
|
Reformulez la question suivante : "{query}"
|
|
|
|
{tag_end}
|
|
|
|
{tag_assistant}
|
|
|
|
Question reformulée : "
|
2024-01-03 16:24:58 +01:00
|
|
|
"""
|
|
|
|
|
2024-01-10 15:06:42 +01:00
|
|
|
self.prefix_assistant_prompt = ''
|
2024-01-03 16:24:58 +01:00
|
|
|
self.embed_model = SentenceTransformer(embed_model_name)
|
|
|
|
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
|
|
|
|
self.collection = self.chroma_client.get_collection(name=collection_name)
|
2024-01-17 21:38:44 +01:00
|
|
|
# ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf
|
|
|
|
self.llm_url = llm_url
|
2024-01-21 21:30:01 +01:00
|
|
|
self.nlp = spacy.load('fr_core_news_md')
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
def answer(self, prompt, stream):
|
2024-01-17 21:38:44 +01:00
|
|
|
|
|
|
|
post_params = {
|
|
|
|
"prompt": prompt,
|
2024-01-21 21:30:01 +01:00
|
|
|
"temp": 0.1,
|
2024-01-18 14:19:18 +01:00
|
|
|
"repeat_penalty": 1.2,
|
2024-01-17 21:38:44 +01:00
|
|
|
"min_p": 0.05,
|
|
|
|
"top_p": 0.5,
|
|
|
|
"top_k": 0,
|
|
|
|
"stop": [self.tag_end, self.tag_user],
|
|
|
|
"stream": stream
|
|
|
|
}
|
|
|
|
|
|
|
|
response = requests.post(self.llm_url, json=post_params, stream=stream)
|
|
|
|
|
|
|
|
if stream: return response
|
|
|
|
else: return response.json()['content']
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-21 21:30:01 +01:00
|
|
|
def keep_token(self, tok):
|
|
|
|
return tok.pos_ == 'NOUN' or tok.pos_ == 'VERB' or \
|
|
|
|
tok.pos_ == 'PROPN' or tok.pos_ == 'ADJ'
|
|
|
|
|
|
|
|
def lemmatize(self, str):
|
|
|
|
res = []
|
|
|
|
for tok in self.nlp(str):
|
|
|
|
if self.keep_token(tok):
|
|
|
|
res.append(tok.lemma_)
|
|
|
|
return res
|
|
|
|
|
|
|
|
def len_query_inter_doc(self, query_str, doc_str):
|
|
|
|
query_tok = self.lemmatize(query_str)
|
|
|
|
doc_tok = self.lemmatize(doc_str)
|
|
|
|
return len(set(query_tok) & set(doc_tok))
|
|
|
|
|
|
|
|
def query_collection(self, query, n_results=4):
|
2024-01-05 13:34:48 +01:00
|
|
|
logging.info(f"query_collection / query: \n{query}")
|
2024-01-17 20:50:28 +01:00
|
|
|
if self.mulitlingual_e5:
|
|
|
|
prefix = "query: "
|
|
|
|
else:
|
|
|
|
prefix = ""
|
2024-01-21 21:30:01 +01:00
|
|
|
query_embedding = self.embed_model.encode(prefix + query, normalize_embeddings=True)
|
2024-01-03 16:24:58 +01:00
|
|
|
query_embedding = query_embedding.tolist()
|
|
|
|
results = self.collection.query(
|
|
|
|
query_embeddings=[query_embedding],
|
|
|
|
n_results=n_results,
|
|
|
|
)
|
2024-01-05 13:34:48 +01:00
|
|
|
|
2024-01-21 21:30:01 +01:00
|
|
|
res = {"passage": [], "url": [], "cat": [], "id": [], "nb_query_tok": [], "nb_tok": [], "dist": []}
|
2024-01-05 13:34:48 +01:00
|
|
|
|
2024-01-17 20:50:28 +01:00
|
|
|
for i in range(len(results["documents"][0])):
|
2024-01-21 21:30:01 +01:00
|
|
|
passage = results["documents"][0][i]
|
|
|
|
# compute the passage's number of tokens, apart from the title and subtitle
|
|
|
|
# which correspond to the first two '\n\n'-separated elements
|
|
|
|
nb_tok = len(self.lemmatize('\n\n'.join(passage.split('\n\n')[2:])))
|
|
|
|
# Retain only long-enough passages
|
|
|
|
if nb_tok > 20:
|
|
|
|
res['id'].append(results["ids"][0][i])
|
|
|
|
res['url'].append(results["metadatas"][0][i]["url"])
|
|
|
|
res['cat'].append(results["metadatas"][0][i]["category"])
|
|
|
|
res['passage'].append(passage)
|
|
|
|
res['nb_query_tok'].append(self.len_query_inter_doc(query, passage))
|
|
|
|
res['nb_tok'].append(nb_tok)
|
|
|
|
res['dist'].append(results["distances"][0][i])
|
|
|
|
|
|
|
|
# Sort the best passages by their number of tokens in common with the query
|
|
|
|
# and then by their cosine distance of their embeddings to the query's embedding
|
|
|
|
sorted_res_values = sorted(zip(*res.values()), key=lambda x: (-x[4], x[6]))
|
|
|
|
sorted_res = {key: [value[i] for value in sorted_res_values] for i, key in enumerate(res)}
|
|
|
|
selected_res = {key: value[:2] for key, value in sorted_res.items()}
|
|
|
|
|
|
|
|
ids_str = ""
|
|
|
|
for id in selected_res['id']:
|
|
|
|
ids_str += id + " ; "
|
|
|
|
logging.info(f"query_collection / sources: \n{ids_str}")
|
|
|
|
|
|
|
|
self.urls = selected_res['url']
|
|
|
|
|
|
|
|
return selected_res
|
2024-01-17 21:38:44 +01:00
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
def format_passages(self, query_results):
|
|
|
|
result = []
|
2024-01-21 21:30:01 +01:00
|
|
|
for i in range(len(query_results["passage"])):
|
|
|
|
passage = query_results["passage"][i]
|
|
|
|
url = query_results["url"][i]
|
|
|
|
category = query_results["cat"][i]
|
2024-01-03 16:24:58 +01:00
|
|
|
lines = passage.split('\n')
|
|
|
|
if lines[0].startswith('passage: '):
|
|
|
|
lines[0] = lines[0].replace('passage: ', '')
|
|
|
|
lines.insert(0, "###")
|
|
|
|
lines.insert(1, f"Passage {i+1}")
|
|
|
|
lines.insert(2, "Titre :")
|
|
|
|
lines.insert(4, "")
|
|
|
|
lines.insert(5, "Catégorie :")
|
|
|
|
lines.insert(6, category)
|
|
|
|
lines.insert(7, "")
|
|
|
|
lines.insert(8, "URL :")
|
|
|
|
lines.insert(9, url)
|
|
|
|
lines.insert(10, "")
|
|
|
|
lines.insert(11, "Contenu : ")
|
|
|
|
lines += ['']
|
|
|
|
result += lines
|
|
|
|
result = '\n'.join(result)
|
|
|
|
return result
|
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
def answer_rag_prompt_streaming(self, prompt):
|
|
|
|
logging.info(f"answer_rag_prompt_streaming: \n{prompt}")
|
|
|
|
ans = self.answer(prompt, stream=True)
|
2024-01-17 20:50:28 +01:00
|
|
|
formated_urls = ''
|
|
|
|
for url in list(set(self.urls)):
|
|
|
|
formated_urls += f"* {url}\n"
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-17 20:50:28 +01:00
|
|
|
yield f"""
|
|
|
|
{self.prefix_assistant_prompt}
|
|
|
|
"""
|
2024-01-17 21:38:44 +01:00
|
|
|
|
|
|
|
|
|
|
|
for line in ans.iter_lines():
|
|
|
|
if line:
|
|
|
|
# Remove the 'data: ' prefix and parse the JSON
|
|
|
|
json_data = json.loads(line.decode('utf-8')[6:])
|
|
|
|
# Print the 'content' field
|
|
|
|
item = json_data['content']
|
|
|
|
self.chat_history[-1]['assistant'] += item
|
|
|
|
yield item
|
|
|
|
|
2024-01-17 20:50:28 +01:00
|
|
|
yield f"""
|
|
|
|
### Sources
|
|
|
|
{formated_urls}
|
|
|
|
"""
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
def answer_rag_prompt_non_streaming(self, prompt):
|
|
|
|
logging.info(f"answer_rag_prompt_non_streaming: \n{prompt}")
|
|
|
|
ans = self.answer(prompt, stream=False)
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
self.chat_history[-1]['assistant'] += ans
|
|
|
|
ans = self.prefix_assistant_prompt + ans
|
|
|
|
return ans
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
def prepare_prompt(self, query, query_results):
|
|
|
|
context = self.format_passages(query_results)
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
history = ""
|
|
|
|
for i in range(len(self.chat_history)):
|
|
|
|
history += f"<|user|>\n{self.chat_history[i]['user']}</s>\n"
|
|
|
|
history += f"<|assistant|>\n{self.chat_history[i]['assistant']}</s>\n"
|
2024-01-03 16:24:58 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
self.chat_history.append({'user': query, 'assistant': self.prefix_assistant_prompt})
|
2024-01-17 21:38:44 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
return self.rag_prompt.format(history=history, query=query, context=context,
|
|
|
|
tag_user=self.tag_user, tag_system=self.tag_system,
|
|
|
|
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
|
2024-01-17 21:38:44 +01:00
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
def reformulate_query(self, query):
|
2024-01-05 13:34:48 +01:00
|
|
|
history = ""
|
|
|
|
for i in reversed(range(len(self.chat_history))):
|
|
|
|
history += f"Question de l'utilisateur :\n{self.chat_history[i]['user']}\n"
|
|
|
|
history += f"Réponse de l'assistant :\n{self.chat_history[i]['assistant']}\n"
|
|
|
|
|
|
|
|
prompt = self.query_reformulate_prompt.format(history=history, query=query,
|
|
|
|
tag_user=self.tag_user, tag_system=self.tag_system,
|
|
|
|
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
|
|
|
|
logging.info(f"reformulate_query: \n{prompt}")
|
|
|
|
ans = self.answer(prompt, stream=False)
|
2024-01-03 16:24:58 +01:00
|
|
|
|
|
|
|
last_quote_index = ans.rfind('"')
|
|
|
|
if last_quote_index != -1:
|
|
|
|
ans = ans[:last_quote_index]
|
|
|
|
|
|
|
|
if len(ans) > 10:
|
|
|
|
logging.info(f"Requête reformulée : \"{ans}\"")
|
|
|
|
return ans
|
|
|
|
else:
|
|
|
|
logging.info(f"La requête n'a pas pu être reformulée.")
|
|
|
|
return query
|
2024-01-17 21:38:44 +01:00
|
|
|
|
2024-01-05 13:34:48 +01:00
|
|
|
def chat(self, query, stream=True):
|
2024-01-03 16:24:58 +01:00
|
|
|
if len(self.chat_history) > 0:
|
|
|
|
query = self.reformulate_query(query)
|
|
|
|
query_results = self.query_collection(query)
|
2024-01-05 13:34:48 +01:00
|
|
|
prompt = self.prepare_prompt(query, query_results)
|
|
|
|
if stream:
|
|
|
|
ans = self.answer_rag_prompt_streaming(prompt)
|
|
|
|
else:
|
|
|
|
ans = self.answer_rag_prompt_non_streaming(prompt)
|
2024-01-03 16:24:58 +01:00
|
|
|
return ans
|
2024-01-17 21:38:44 +01:00
|
|
|
|
2024-01-03 16:24:58 +01:00
|
|
|
def reset_history(self):
|
|
|
|
self.chat_history = []
|