replace llama-cpp-python with llamafile server, change sampling to min-p

This commit is contained in:
Pierre-Edouard Portier 2024-01-17 21:38:44 +01:00
parent 8f78f6c656
commit b07068ed60
2 changed files with 40 additions and 22 deletions

5
app.py
View File

@ -1,3 +1,4 @@
# python -m streamlit run app.py
import streamlit as st
from rag import RAG
import re
@ -5,12 +6,12 @@ import logging
@st.cache_resource
def init_rag():
    """Build the app's single RAG instance, cached across Streamlit reruns.

    Returns the configured RAG object wired to the local llamafile
    completion server and the French sentence-embedding model.

    NOTE(review): this span was a rendered diff; the superseded
    llama-cpp-python model-path lines have been dropped in favour of the
    llamafile server URL this commit introduces.
    """
    # llamafile completion endpoint; the server is started separately, e.g.:
    #   ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m zephyr-7b-beta.Q5_K_M.gguf
    llm_url = 'http://127.0.0.1:8080/completion'
    # embed_model_name = 'intfloat/multilingual-e5-large'
    embed_model_name = 'dangvantuan/sentence-camembert-large'
    collection_name = 'cera'
    chromadb_path = './chromadb'
    # 'mulitlingual_e5' (sic) is the spelling in RAG's public signature — keep it.
    return RAG(llm_url, embed_model_name, collection_name, chromadb_path,
               mulitlingual_e5=False)

rag = init_rag()

57
rag.py
View File

@ -3,12 +3,14 @@ import chromadb
from llama_cpp import Llama
import copy
import logging
import json
import requests
logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class RAG:
def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
def __init__(self, llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
logging.info('INIT')
self.mulitlingual_e5 = mulitlingual_e5
self.urls = []
@ -107,18 +109,26 @@ Question reformulée : "
self.embed_model = SentenceTransformer(embed_model_name)
self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
self.collection = self.chroma_client.get_collection(name=collection_name)
self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
# ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf
self.llm_url = llm_url
def answer(self, prompt, stream):
    """Send `prompt` to the llamafile `/completion` endpoint and return the answer.

    Parameters
    ----------
    prompt : str
        Fully-formatted prompt (history + context already rendered in).
    stream : bool
        When True, return the raw streaming `requests.Response` so the caller
        can iterate server-sent events; when False, block and return the
        generated text (the `content` field of the JSON reply).

    NOTE(review): this span was a rendered diff; the superseded
    llama-cpp-python call (mirostat sampling) has been dropped in favour of
    the requests-based llamafile call with min-p sampling this commit adds.
    """
    post_params = {
        "prompt": prompt,
        "temp": 0.1,            # low temperature: near-deterministic answers
        "repeat_penalty": 1.1,
        "min_p": 0.05,          # min-p sampling — the commit's stated goal
        "top_p": 0.5,
        "top_k": 0,             # 0 disables top-k filtering in llama.cpp servers
        "stop": [self.tag_end, self.tag_user],
        "stream": stream,
    }
    response = requests.post(self.llm_url, json=post_params, stream=stream)
    if stream:
        return response
    return response.json()['content']
def query_collection(self, query, n_results=3):
logging.info(f"query_collection / query: \n{query}")
@ -145,7 +155,7 @@ Question reformulée : "
self.urls.append(results["metadatas"][0][i]["url"])
return results
def format_passages(self, query_results):
result = []
for i in range(len(query_results["documents"][0])):
@ -182,10 +192,17 @@ Question reformulée : "
### Réponse
{self.prefix_assistant_prompt}
"""
for item in ans:
item = item["choices"][0]["text"]
self.chat_history[-1]['assistant'] += item
yield item
for line in ans.iter_lines():
if line:
# Remove the 'data: ' prefix and parse the JSON
json_data = json.loads(line.decode('utf-8')[6:])
# Print the 'content' field
item = json_data['content']
self.chat_history[-1]['assistant'] += item
yield item
yield f"""
### Sources
{formated_urls}
@ -208,11 +225,11 @@ Question reformulée : "
history += f"<|assistant|>\n{self.chat_history[i]['assistant']}</s>\n"
self.chat_history.append({'user': query, 'assistant': self.prefix_assistant_prompt})
return self.rag_prompt.format(history=history, query=query, context=context,
tag_user=self.tag_user, tag_system=self.tag_system,
tag_assistant=self.tag_assistant, tag_end=self.tag_end)
def reformulate_query(self, query):
history = ""
for i in reversed(range(len(self.chat_history))):
@ -235,7 +252,7 @@ Question reformulée : "
else:
logging.info(f"La requête n'a pas pu être reformulée.")
return query
def chat(self, query, stream=True):
if len(self.chat_history) > 0:
query = self.reformulate_query(query)
@ -246,6 +263,6 @@ Question reformulée : "
else:
ans = self.answer_rag_prompt_non_streaming(prompt)
return ans
def reset_history(self):
self.chat_history = []