replace llama-cpp-python with llamafile server, change sampling to min-p
parent 8f78f6c656
commit b07068ed60
app.py (5 changed lines)
@@ -1,3 +1,4 @@
+# python -m streamlit run app.py
 import streamlit as st
 from rag import RAG
 import re
@@ -5,12 +6,12 @@ import logging
 
 @st.cache_resource
 def init_rag():
-    llm_model_path = '/Users/peportier/llm/a/a/zephyr-7b-beta.Q5_K_M.gguf'
+    llm_url = 'http://127.0.0.1:8080/completion'
     # embed_model_name = 'intfloat/multilingual-e5-large'
     embed_model_name = 'dangvantuan/sentence-camembert-large'
     collection_name = 'cera'
     chromadb_path = './chromadb'
-    rag = RAG(llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=False)
+    rag = RAG(llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=False)
     return rag
 
 rag = init_rag()
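Note: with this change, init_rag no longer loads the model in-process; it assumes a llamafile server is already listening on 127.0.0.1:8080 (the launch command appears as a comment in the rag.py hunk below). A minimal smoke test in that spirit, not part of the commit; the one-token n_predict request is an assumption based on the llama.cpp server API that llamafile embeds:

# smoke_test.py - illustrative check that the llamafile server is reachable
# before launching the Streamlit app (hypothetical helper, not in the commit)
import requests

LLM_URL = 'http://127.0.0.1:8080/completion'  # same URL app.py hard-codes

def server_is_up(url=LLM_URL, timeout=5):
    """POST a one-token request; a well-formed JSON reply means the server is live."""
    try:
        r = requests.post(url, json={'prompt': 'ping', 'n_predict': 1}, timeout=timeout)
        return r.ok and 'content' in r.json()
    except (requests.RequestException, ValueError):
        return False

if __name__ == '__main__':
    print('llamafile server reachable:', server_is_up())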
rag.py (43 changed lines)
@@ -3,12 +3,14 @@ import chromadb
 from llama_cpp import Llama
 import copy
 import logging
+import json
+import requests
 
 logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 class RAG:
-    def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
+    def __init__(self, llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
         logging.info('INIT')
         self.mulitlingual_e5 = mulitlingual_e5
         self.urls = []
@@ -107,18 +109,26 @@ Question reformulée : "
         self.embed_model = SentenceTransformer(embed_model_name)
         self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
         self.collection = self.chroma_client.get_collection(name=collection_name)
-        self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
+        # ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf
+        self.llm_url = llm_url
 
     def answer(self, prompt, stream):
-        response = self.llm(prompt = prompt,
-                            temperature = 0.1,
-                            mirostat_mode = 2,
-                            stream = stream,
-                            max_tokens = -1,
-                            stop = [self.tag_end, self.tag_user])
-        if stream:
-            return response
-        else: return response["choices"][0]["text"]
+        post_params = {
+            "prompt": prompt,
+            "temp": 0.1,
+            "repeat_penalty": 1.1,
+            "min_p": 0.05,
+            "top_p": 0.5,
+            "top_k": 0,
+            "stop": [self.tag_end, self.tag_user],
+            "stream": stream
+        }
+
+        response = requests.post(self.llm_url, json=post_params, stream=stream)
+
+        if stream: return response
+        else: return response.json()['content']
 
     def query_collection(self, query, n_results=3):
         logging.info(f"query_collection / query: \n{query}")
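Note: this hunk carries the sampling change named in the commit title: mirostat_mode = 2 is dropped in favour of min_p = 0.05, with top_k = 0 disabling top-k entirely. Under min-p, a token survives filtering only if its probability is at least min_p times that of the most likely token, so the cut-off adapts to the model's confidence rather than to a fixed probability mass as with top_p. A toy sketch of the rule with made-up numbers (not the llamafile implementation):

# Toy illustration of min-p filtering; the distribution below is invented.
def min_p_filter(probs, min_p=0.05):
    """Keep tokens whose probability is >= min_p * the maximum probability."""
    threshold = min_p * max(probs.values())
    return {tok: p for tok, p in probs.items() if p >= threshold}

probs = {'la': 0.60, 'le': 0.25, 'une': 0.10, 'et': 0.04, 'zèbre': 0.01}
print(min_p_filter(probs))         # threshold 0.03 -> drops only 'zèbre'
print(min_p_filter(probs, 0.2))    # threshold 0.12 -> keeps 'la' and 'le'

A peaked distribution therefore keeps few candidates while a flat one keeps many, which is the usual argument for min-p over a fixed top_p.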
@@ -182,10 +192,17 @@ Question reformulée : "
 ### Réponse
 {self.prefix_assistant_prompt}
 """
-        for item in ans:
-            item = item["choices"][0]["text"]
+        for line in ans.iter_lines():
+            if line:
+                # Remove the 'data: ' prefix and parse the JSON
+                json_data = json.loads(line.decode('utf-8')[6:])
+                # Print the 'content' field
+                item = json_data['content']
             self.chat_history[-1]['assistant'] += item
             yield item
 
         yield f"""
 ### Sources
 {formated_urls}
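Note: the streaming branch now consumes the llamafile server's server-sent events: each chunk arrives as a line of the form data: {...}, and the [6:] slice strips the 6-character 'data: ' prefix before json.loads. A standalone sketch of the same loop, not part of the commit; treating a stop field in the final chunk as the end-of-stream signal is an assumption based on the llama.cpp server convention:

# Standalone sketch mirroring the streaming loop in the diff above.
import json
import requests

def stream_completion(url, prompt):
    """Yield content chunks from a llamafile /completion SSE stream."""
    resp = requests.post(url, json={'prompt': prompt, 'stream': True}, stream=True)
    for line in resp.iter_lines():
        if not line:
            continue  # skip SSE keep-alive blank lines
        payload = json.loads(line.decode('utf-8')[6:])  # strip the 'data: ' prefix
        yield payload['content']
        if payload.get('stop'):  # assumed final-chunk flag (llama.cpp convention)
            break

# Example usage:
# for chunk in stream_completion('http://127.0.0.1:8080/completion', 'Bonjour'):
#     print(chunk, end='', flush=True)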