2024-01-03 10:24:58 -05:00
from sentence_transformers import SentenceTransformer
import chromadb
from llama_cpp import Llama
import copy
import logging
import re
logging . basicConfig ( level = logging . INFO , format = ' %(asctime)s - %(levelname)s - %(message)s ' )
class RAG :
def __init__ ( self , llm_model_path , embed_model_name , collection_name , chromadb_path ) :
self . chat_history = [ ]
self . rag_system_prompt = """
Vous êtes un assistant IA qui répond à la question posée par l ' utilisateur en utilisant un contexte répertorié ci-dessous dans la rubrique Contexte.
Le contexte est formé de passages exraits du site web commercial de la Caisse d ' Epargne Rhône-Alpes, une banque française régionale.
Votre réponse ne doit pas mentionner des informations déjà présentes dans l ' historique de la conversation qui est répertorié ci-dessous dans la rubrique Historique.
Vous fournissez avec soin des réponses précises , factuelles , réfléchies et nuancées , et vous êtes doué pour le raisonnement .
Toutes les informations factuelles que vous utilisez pour répondre proviennent exclusivement du contexte .
Si vous ne pouvez pas répondre à la question sur la base des éléments du contexte , dites simplement que vous ne savez pas , n ' essayez pas d ' inventer une réponse .
Vos réponses doivent être brèves .
Vos utilisateurs savent que vos réponses sont brèves et qu ' elles ne mentionnent que les éléments du contexte, il n ' est pas nécessaire de le leur rappeler .
Vous gardez le rôle d ' assistant et vous ne générez jamais le texte ' < | user | > ' .
Vous rédigez vos réponses en français au format markdown sous forme d ' une liste composée de 1 à 7 éléments au maximum.
Voici le format que doit prendre votre réponse :
` ` `
Voici des éléments de réponse :
2024-01-03 15:36:33 -05:00
1. Elément de réponse .
2. Elément de réponse .
3. Elément de réponse .
2024-01-03 10:24:58 -05:00
4. . . .
` ` `
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Historique :
{ history }
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Contexte :
{ context }
"""
self . query_reformulate_system_prompt = """
Vous êtes un interprète conversationnel pour une conversation entre un utilisateur et \
un assistant IA spécialiste des produits et services de la Caisse d ' Epargne Rhône-Alpes, \
une banque régionale française .
L ' utilisateur vous posera une question sans contexte. \
Vous devez reformuler la question pour prendre en compte le contexte de la conversation .
Vous devez supposer que la question est liée aux produits et services de la Caisse d ' Epargne Rhône-Alpes.
Vous devez également consulter l ' historique de la conversation ci-dessous lorsque vous reformulez la question. \
Par exemple , vous remplacerez les pronoms par les noms les plus probables dans l ' historique de la conversation.
Lorsque vous reformulez la question , accordez plus d ' importance à la dernière question et \
à la dernière réponse dans l ' historique des conversations.
L ' historique des conversations est présenté dans l ' ordre chronologique inverse , \
de sorte que l ' échange le plus récent se trouve en haut de la page.
Répondez en seulement une phrase avec la question reformulée .
Historique de la conversation :
"""
self . embed_model = SentenceTransformer ( embed_model_name )
self . chroma_client = chromadb . PersistentClient ( path = chromadb_path )
self . collection = self . chroma_client . get_collection ( name = collection_name )
self . llm = Llama ( model_path = llm_model_path , n_gpu_layers = 1 , use_mlock = True , n_ctx = 4096 )
def answer ( self , prompt , stream = False ) :
response = self . llm ( prompt = prompt ,
temperature = 0.1 ,
mirostat_mode = 2 ,
stream = stream ,
max_tokens = - 1 ,
stop = [ ' </s> ' , ' 8. ' , ' \n \n ' , ' <|user|> ' ] )
if stream :
return response
else : return response [ " choices " ] [ 0 ] [ " text " ]
def query_collection ( self , query , n_results = 3 ) :
query = ' query: ' + query
query_embedding = self . embed_model . encode ( query , normalize_embeddings = True )
query_embedding = query_embedding . tolist ( )
results = self . collection . query (
query_embeddings = [ query_embedding ] ,
n_results = n_results ,
)
return results
def format_passages ( self , query_results ) :
result = [ ]
for i in range ( len ( query_results [ " documents " ] [ 0 ] ) ) :
passage = query_results [ " documents " ] [ 0 ] [ i ]
url = query_results [ " metadatas " ] [ 0 ] [ i ] [ " url " ]
category = query_results [ " metadatas " ] [ 0 ] [ i ] [ " category " ]
lines = passage . split ( ' \n ' )
if lines [ 0 ] . startswith ( ' passage: ' ) :
lines [ 0 ] = lines [ 0 ] . replace ( ' passage: ' , ' ' )
lines . insert ( 0 , " ### " )
lines . insert ( 1 , f " Passage { i + 1 } " )
lines . insert ( 2 , " Titre : " )
lines . insert ( 4 , " " )
lines . insert ( 5 , " Catégorie : " )
lines . insert ( 6 , category )
lines . insert ( 7 , " " )
lines . insert ( 8 , " URL : " )
lines . insert ( 9 , url )
lines . insert ( 10 , " " )
lines . insert ( 11 , " Contenu : " )
lines + = [ ' ' ]
result + = lines
result = ' \n ' . join ( result )
return result
def format_rag_prompt ( self , query , context = " " , history = " " ) :
user_query = f " Question de l ' utilisateur : \n { query } \n \n "
assistant_answer = f " Réponse de l ' assistant : \n 1. "
self . chat_history . append ( { ' user ' : user_query , ' assistant ' : assistant_answer } )
system_prompt = self . rag_system_prompt . format ( history = history , context = context )
prompt = " "
prompt = f " <|system|> \n { system_prompt . strip ( ) } </s> \n "
prompt + = f " <|user|> \n { query } </s> \n "
prompt + = f " <|assistant|> \n Voici des éléments de réponse : \n 1. "
return prompt
def remove_references ( self , text ) :
motif = r " \ (Passage \ d+ \ ) "
res = re . sub ( motif , ' ' , text )
return res
def answer_rag_prompt ( self , query , query_results , stream = False ) :
query_context = self . format_passages ( query_results )
history = " "
for i in reversed ( range ( len ( self . chat_history ) ) ) :
history + = self . chat_history [ i ] [ " user " ]
history + = self . remove_references ( self . chat_history [ i ] [ " assistant " ] )
history + = " \n \n "
prompt = self . format_rag_prompt ( query , query_context , history )
logging . info ( prompt )
ans = self . answer ( prompt , stream )
if stream :
yield ' 1. '
for item in ans :
item = item [ " choices " ] [ 0 ] [ " text " ]
self . chat_history [ - 1 ] [ ' assistant ' ] + = item
yield item
else :
self . chat_history [ - 1 ] [ ' assistant ' ] + = ans
ans = ' 1. ' + ans
return ans
def format_prompt_reformulate_query ( self , query ) :
system_prompt = self . query_reformulate_system_prompt
for i in reversed ( range ( len ( self . chat_history ) ) ) :
system_prompt + = self . chat_history [ i ] [ " user " ]
system_prompt + = self . chat_history [ i ] [ " assistant " ]
prompt = " "
prompt = f " <|system|> \n { system_prompt . strip ( ) } </s> \n "
prompt + = f " <|user|> \n Peux-tu reformuler la question suivante : \n \" { query } \" </s> \n "
prompt + = f " <|assistant|> Question reformulée : \n \" "
return prompt
def reformulate_query ( self , query ) :
prompt = self . format_prompt_reformulate_query ( query )
logging . info ( prompt )
ans = self . answer ( prompt )
last_quote_index = ans . rfind ( ' " ' )
if last_quote_index != - 1 :
ans = ans [ : last_quote_index ]
if len ( ans ) > 10 :
logging . info ( f " Requête reformulée : \" { ans } \" " )
return ans
else :
logging . info ( f " La requête n ' a pas pu être reformulée. " )
return query
def chat ( self , query ) :
if len ( self . chat_history ) > 0 :
query = self . reformulate_query ( query )
query_results = self . query_collection ( query )
ans = self . answer_rag_prompt ( query , query_results , stream = True )
return ans
def reset_history ( self ) :
self . chat_history = [ ]