replace llama-cpp-python with llamafile server, change sampling to min-p
parent 8f78f6c656
commit b07068ed60
app.py (5 changed lines)
@@ -1,3 +1,4 @@
 # python -m streamlit run app.py
 import streamlit as st
 from rag import RAG
+import re
@@ -5,12 +6,12 @@ import logging
 
 @st.cache_resource
 def init_rag():
-    llm_model_path = '/Users/peportier/llm/a/a/zephyr-7b-beta.Q5_K_M.gguf'
+    llm_url = 'http://127.0.0.1:8080/completion'
     # embed_model_name = 'intfloat/multilingual-e5-large'
     embed_model_name = 'dangvantuan/sentence-camembert-large'
     collection_name = 'cera'
     chromadb_path = './chromadb'
-    rag = RAG(llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=False)
+    rag = RAG(llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=False)
     return rag
 
 rag = init_rag()
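With this change app.py no longer loads a GGUF file itself; it only points at a llamafile server that must already be running on 127.0.0.1:8080. A minimal pre-flight check before launching Streamlit could look like the sketch below. It is not part of the commit; the parameter names mirror what the diff sends to /completion, and n_predict is assumed from the llama.cpp server API that llamafile wraps.

import requests

# Hypothetical sanity check: ask the llamafile server for a tiny completion
# before running `python -m streamlit run app.py`.
llm_url = 'http://127.0.0.1:8080/completion'
resp = requests.post(llm_url, json={
    'prompt': 'Bonjour',
    'temp': 0.1,
    'n_predict': 8,   # assumed llama.cpp-server option capping the reply length
    'stream': False,
})
resp.raise_for_status()
print(resp.json()['content'])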
rag.py (57 changed lines)
@@ -3,12 +3,14 @@ import chromadb
 from llama_cpp import Llama
 import copy
 import logging
+import json
+import requests
 
 logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 class RAG:
-    def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
+    def __init__(self, llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
         logging.info('INIT')
         self.mulitlingual_e5 = mulitlingual_e5
         self.urls = []
@@ -107,18 +109,26 @@ Question reformulée : "
         self.embed_model = SentenceTransformer(embed_model_name)
         self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
         self.collection = self.chroma_client.get_collection(name=collection_name)
-        self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
+        # ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf
+        self.llm_url = llm_url
 
     def answer(self, prompt, stream):
-        response = self.llm(prompt = prompt,
-                            temperature = 0.1,
-                            mirostat_mode = 2,
-                            stream = stream,
-                            max_tokens = -1,
-                            stop = [self.tag_end, self.tag_user])
-        if stream:
-            return response
-        else: return response["choices"][0]["text"]
+
+        post_params = {
+            "prompt": prompt,
+            "temp": 0.1,
+            "repeat_penalty": 1.1,
+            "min_p": 0.05,
+            "top_p": 0.5,
+            "top_k": 0,
+            "stop": [self.tag_end, self.tag_user],
+            "stream": stream
+        }
+
+        response = requests.post(self.llm_url, json=post_params, stream=stream)
+
+        if stream: return response
+        else: return response.json()['content']
 
     def query_collection(self, query, n_results=3):
         logging.info(f"query_collection / query: \n{query}")
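This hunk carries both halves of the commit message: the local Llama object becomes an HTTP call to the llamafile server, and sampling moves from mirostat (mirostat_mode = 2) to min-p (min_p = 0.05, with top_k = 0 to disable top-k and a loose top_p of 0.5). Min-p keeps only the tokens whose probability is at least min_p times the probability of the most likely token, so the cutoff tightens when the model is confident and relaxes when it is not. A rough illustration of that rule, not llamafile's implementation:

def min_p_filter(probs, min_p=0.05):
    # Keep tokens whose probability is at least min_p * max(probs),
    # then renormalise; sampling would draw from the surviving tokens.
    threshold = min_p * max(probs.values())
    kept = {tok: p for tok, p in probs.items() if p >= threshold}
    total = sum(kept.values())
    return {tok: p / total for tok, p in kept.items()}

# Peaked distribution: everything below 0.05 * 0.90 is pruned.
print(min_p_filter({'a': 0.90, 'b': 0.06, 'c': 0.04}))
# Flat distribution: every candidate clears the 0.05 * 0.30 threshold.
print(min_p_filter({'a': 0.30, 'b': 0.28, 'c': 0.22, 'd': 0.20}))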
@@ -145,7 +155,7 @@ Question reformulée : "
             self.urls.append(results["metadatas"][0][i]["url"])
 
         return results
 
 
     def format_passages(self, query_results):
         result = []
         for i in range(len(query_results["documents"][0])):
@@ -182,10 +192,17 @@ Question reformulée : "
 ### Réponse
 {self.prefix_assistant_prompt}
 """
-        for item in ans:
-            item = item["choices"][0]["text"]
-            self.chat_history[-1]['assistant'] += item
-            yield item
+
+
+        for line in ans.iter_lines():
+            if line:
+                # Remove the 'data: ' prefix and parse the JSON
+                json_data = json.loads(line.decode('utf-8')[6:])
+                # Print the 'content' field
+                item = json_data['content']
+                self.chat_history[-1]['assistant'] += item
+                yield item
+
         yield f"""
 ### Sources
 {formated_urls}
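With stream=True the llamafile server replies as server-sent events, one data: {...} line per generated chunk, which is why the new loop strips the first six characters before calling json.loads. A standalone version of the same idea, outside the class (a sketch; the prompt and URL are only illustrative):

import json
import requests

llm_url = 'http://127.0.0.1:8080/completion'
resp = requests.post(llm_url, json={'prompt': 'Bonjour', 'stream': True}, stream=True)

for line in resp.iter_lines():
    if not line:                          # skip keep-alive blank lines between events
        continue
    payload = line.decode('utf-8')
    if payload.startswith('data: '):      # same prefix the diff drops with [6:]
        chunk = json.loads(payload[len('data: '):])
        print(chunk['content'], end='', flush=True)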
@@ -208,11 +225,11 @@ Question reformulée : "
             history += f"<|assistant|>\n{self.chat_history[i]['assistant']}</s>\n"
 
         self.chat_history.append({'user': query, 'assistant': self.prefix_assistant_prompt})
 
 
         return self.rag_prompt.format(history=history, query=query, context=context,
                                       tag_user=self.tag_user, tag_system=self.tag_system,
                                       tag_assistant=self.tag_assistant, tag_end=self.tag_end)
 
 
     def reformulate_query(self, query):
         history = ""
         for i in reversed(range(len(self.chat_history))):
@@ -235,7 +252,7 @@ Question reformulée : "
         else:
             logging.info(f"La requête n'a pas pu être reformulée.")
             return query
 
 
     def chat(self, query, stream=True):
         if len(self.chat_history) > 0:
             query = self.reformulate_query(query)
@@ -246,6 +263,6 @@ Question reformulée : "
         else:
             ans = self.answer_rag_prompt_non_streaming(prompt)
         return ans
 
 
     def reset_history(self):
         self.chat_history = []