From b07068ed600d8c8eda0fa69ae4dbbe561bbd11b1 Mon Sep 17 00:00:00 2001
From: Pierre-Edouard Portier
Date: Wed, 17 Jan 2024 21:38:44 +0100
Subject: [PATCH] replace llama-cpp-python with llamafile server, change sampling to min-p

---
 app.py |  5 +++--
 rag.py | 57 +++++++++++++++++++++++++++++++++++++--------------------
 2 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/app.py b/app.py
index 854bd79..0d2d0ea 100644
--- a/app.py
+++ b/app.py
@@ -1,3 +1,4 @@
+# python -m streamlit run app.py
 import streamlit as st
 from rag import RAG
 import re
@@ -5,12 +6,12 @@ import logging
 
 @st.cache_resource
 def init_rag():
-    llm_model_path = '/Users/peportier/llm/a/a/zephyr-7b-beta.Q5_K_M.gguf'
+    llm_url = 'http://127.0.0.1:8080/completion'
     # embed_model_name = 'intfloat/multilingual-e5-large'
     embed_model_name = 'dangvantuan/sentence-camembert-large'
     collection_name = 'cera'
     chromadb_path = './chromadb'
-    rag = RAG(llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=False)
+    rag = RAG(llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=False)
     return rag
 
 rag = init_rag()
diff --git a/rag.py b/rag.py
index 666f7bb..8a6ec53 100644
--- a/rag.py
+++ b/rag.py
@@ -3,12 +3,14 @@
 import chromadb
 from llama_cpp import Llama
 import copy
 import logging
+import json
+import requests
 
 logging.basicConfig(filename='rag.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 class RAG:
-    def __init__(self, llm_model_path, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
+    def __init__(self, llm_url, embed_model_name, collection_name, chromadb_path, mulitlingual_e5=True):
         logging.info('INIT')
         self.mulitlingual_e5 = mulitlingual_e5
         self.urls = []
@@ -107,18 +109,26 @@ Question reformulée : "
         self.embed_model = SentenceTransformer(embed_model_name)
         self.chroma_client = chromadb.PersistentClient(path=chromadb_path)
         self.collection = self.chroma_client.get_collection(name=collection_name)
-        self.llm = Llama(model_path=llm_model_path, n_gpu_layers=1, use_mlock=True, n_ctx=4096)
+        # ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf
+        self.llm_url = llm_url
 
     def answer(self, prompt, stream):
-        response = self.llm(prompt = prompt,
-                            temperature = 0.1,
-                            mirostat_mode = 2,
-                            stream = stream,
-                            max_tokens = -1,
-                            stop = [self.tag_end, self.tag_user])
-        if stream:
-            return response
-        else: return response["choices"][0]["text"]
+
+        post_params = {
+            "prompt": prompt,
+            "temp": 0.1,
+            "repeat_penalty": 1.1,
+            "min_p": 0.05,
+            "top_p": 0.5,
+            "top_k": 0,
+            "stop": [self.tag_end, self.tag_user],
+            "stream": stream
+        }
+
+        response = requests.post(self.llm_url, json=post_params, stream=stream)
+
+        if stream: return response
+        else: return response.json()['content']
 
     def query_collection(self, query, n_results=3):
         logging.info(f"query_collection / query: \n{query}")
@@ -145,7 +155,7 @@ Question reformulée : "
                 self.urls.append(results["metadatas"][0][i]["url"])
 
         return results
-    
+
     def format_passages(self, query_results):
         result = []
         for i in range(len(query_results["documents"][0])):
@@ -182,10 +192,17 @@ Question reformulée : "
 ### Réponse
 {self.prefix_assistant_prompt}
 """
-        for item in ans:
-            item = item["choices"][0]["text"]
-            self.chat_history[-1]['assistant'] += item
-            yield item
+
+        for line in ans.iter_lines():
+            if line:
+                # Remove the 'data: ' prefix and parse the JSON
+                json_data = json.loads(line.decode('utf-8')[6:])
+                # Print the 'content' field
+                item = json_data['content']
+                self.chat_history[-1]['assistant'] += item
+                yield item
+
         yield f"""
 ### Sources
 {formated_urls}
 """
@@ -208,11 +225,11 @@ Question reformulée : "
                 history += f"<|assistant|>\n{self.chat_history[i]['assistant']}\n"
 
         self.chat_history.append({'user': query, 'assistant': self.prefix_assistant_prompt})
-        
+
         return self.rag_prompt.format(history=history, query=query, context=context,
                                       tag_user=self.tag_user, tag_system=self.tag_system,
                                       tag_assistant=self.tag_assistant, tag_end=self.tag_end)
-    
+
     def reformulate_query(self, query):
         history = ""
         for i in reversed(range(len(self.chat_history))):
@@ -235,7 +252,7 @@ Question reformulée : "
             else:
                 logging.info(f"La requête n'a pas pu être reformulée.")
                 return query
-    
+
     def chat(self, query, stream=True):
         if len(self.chat_history) > 0:
             query = self.reformulate_query(query)
@@ -246,6 +263,6 @@ Question reformulée : "
         else:
             ans = self.answer_rag_prompt_non_streaming(prompt)
         return ans
-    
+
     def reset_history(self):
         self.chat_history = []
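
A minimal standalone client for the new endpoint, useful for smoke-testing the llamafile server independently of the Streamlit app. This is a sketch, not part of the patch: it assumes a server launched as in the comment added to rag.py and listening on 127.0.0.1:8080, and it reuses the same min-p sampling parameters and 'data: ' stream framing that RAG.answer() and the streaming loop rely on; the file name and the prompt are placeholders.

# smoke_test.py (sketch; assumes a llamafile server started with, e.g.:
#   ./llamafile -n -1 -c 4096 --mlock --gpu APPLE -m a/a/zephyr-7b-beta.Q5_K_M.gguf)
import json
import requests

LLM_URL = 'http://127.0.0.1:8080/completion'  # same URL as in app.py

def complete(prompt, stream=True):
    # Same sampling parameters as RAG.answer() in the patch.
    post_params = {
        "prompt": prompt,
        "temp": 0.1,
        "repeat_penalty": 1.1,
        "min_p": 0.05,  # min-p: drop tokens below 5% of the most likely token's probability
        "top_p": 0.5,
        "top_k": 0,     # 0 disables top-k, so truncation is driven by min-p
        "stream": stream
    }
    response = requests.post(LLM_URL, json=post_params, stream=stream)
    if not stream:
        return response.json()['content']
    # Streamed completions arrive as server-sent events, one 'data: {...}' line per chunk.
    chunks = []
    for line in response.iter_lines():
        if line:
            chunks.append(json.loads(line.decode('utf-8')[6:])['content'])
    return ''.join(chunks)

if __name__ == '__main__':
    print(complete("Bonjour !"))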