From 56e9cde86e210d644e745e1d6336f91c6bf907f9 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 7 May 2023 17:32:10 +0800
Subject: [PATCH] fetch top-3 similar answers

---
 pilot/configs/model_config.py |  2 +-
 pilot/conversation.py         |  2 +-
 pilot/model/vicuna_llm.py     | 32 +++++++++++--------------
 pilot/server/embdserver.py    | 44 +++++++++++++++++++++++------------
 pilot/server/vectordb_qa.py   | 25 +++++++++-----------
 pilot/server/webserver.py     |  8 ++++---
 6 files changed, 61 insertions(+), 52 deletions(-)

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 149ddb296..df0318e2d 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -22,7 +22,7 @@ LLM_MODEL_CONFIG = {
 }
 
 
-VECTOR_SEARCH_TOP_K = 5
+VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
diff --git a/pilot/conversation.py b/pilot/conversation.py
index 688a5c70d..88d5ca591 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -147,7 +147,7 @@ conv_vicuna_v1 = Conversation(
 )
 
 
-conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
             如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
 
 已知内容:
diff --git a/pilot/model/vicuna_llm.py b/pilot/model/vicuna_llm.py
index f17a17a00..2337a3bbf 100644
--- a/pilot/model/vicuna_llm.py
+++ b/pilot/model/vicuna_llm.py
@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
 from pilot.configs.model_config import *
 
-class VicunaRequestLLM(LLM):
+class VicunaLLM(LLM):
+
+    vicuna_generate_path = "generate_stream"
+    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
 
-    vicuna_generate_path = "generate"
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        if isinstance(stop, list):
-            stop = stop + ["Observation:"]
-
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
         params = {
             "prompt": prompt,
-            "temperature": 0.7,
-            "max_new_tokens": 1024,
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
             "stop": stop
         }
         response = requests.post(
             url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
-        response.raise_for_status()
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        #     if chunk:
-        #         data = json.loads(chunk.decode())
-        #         if data["error_code"] == 0:
-        #             output = data["text"][skip_echo_len:].strip()
-        #             output = self.post_process_code(output)
-        #             yield output
-        return response.json()["response"]
+
+        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    yield output
 
     @property
     def _llm_type(self) -> str:
diff --git a/pilot/server/embdserver.py b/pilot/server/embdserver.py
index 5e0ad9294..6599a18ad 100644
--- a/pilot/server/embdserver.py
+++ b/pilot/server/embdserver.py
@@ -4,29 +4,44 @@
 import requests
 import json
 import time
+import uuid
 from urllib.parse import urljoin
 import gradio as gr
 from pilot.configs.model_config import *
 
-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate
 
-def generate(prompt):
+vicuna_stream_path = "generate_stream"
+
+def generate(query):
+
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+
+    prompt = state.get_prompt()
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }
-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + "▌"
                 yield(output)
-
-    time.sleep(0.02)
-
+
 if __name__ == "__main__":
     print(LLM_MODEL)
     with gr.Blocks() as demo:
diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
index 083ce20cd..f0b33f346 100644
--- a/pilot/server/vectordb_qa.py
+++ b/pilot/server/vectordb_qa.py
@@ -3,31 +3,28 @@
 
 from pilot.vector_store.file_loader import KnownLedge2Vector
 from langchain.prompts import PromptTemplate
-from pilot.conversation import conv_qk_prompt_template
+from pilot.conversation import conv_qa_prompt_template
 from langchain.chains import RetrievalQA
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
+from pilot.model.vicuna_llm import VicunaLLM
 
 class KnownLedgeBaseQA:
-    llm: object = None
-
     def __init__(self) -> None:
         k2v = KnownLedge2Vector()
         self.vector_store = k2v.init_vector_store()
+        self.llm = VicunaLLM()
 
-    def get_answer(self, query):
-        prompt_template = conv_qk_prompt_template
+    def get_similar_answer(self, query):
 
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=conv_qa_prompt_template,
            input_variables=["context", "question"]
         )
 
-        knownledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
-            prompt=prompt
-        )
-        knownledge_chain.return_source_documents = True
-        result = knownledge_chain({"query": query})
-        yield result
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
+        docs = retriever.get_relevant_documents(query=query)
+
+        context = [d.page_content for d in docs]
+        result = prompt.format(context="\n".join(context), question=query)
+        return result
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index dd6753a3b..9d77aa50b 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -170,7 +170,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
 
         query = state.messages[-2][1]
 
-        # Add a context hint to the prompt
+        # Add a context hint to the prompt for knowledge-grounded dialogue. Should the hint only go in the first turn, or be added in every turn?
+        # If the user's questions span very different topics, the hint should be added in every turn.
         if db_selector:
             new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
             new_state.append_message(new_state.roles[1], None)
@@ -179,7 +180,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
             new_state.append_message(new_state.roles[0], query)
             new_state.append_message(new_state.roles[1], None)
             state = new_state
-
+
     # try:
     #     if not db_selector:
     #         sim_q = get_simlar(query)
@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
-                time.sleep(0.02)
+
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
+    # Log this run
     finish_tstamp = time.time()
     logger.info(f"{output}")
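
Note: a minimal usage sketch (not part of the patch) of how the pieces introduced here compose. KnownLedgeBaseQA.get_similar_answer builds a knowledge-grounded prompt from the top-3 similar chunks (VECTOR_SEARCH_TOP_K = 3), and the reworked VicunaLLM._call streams the completion from the model server's generate_stream endpoint. The generate_answer helper, the sample question, and the sampling parameters are illustrative assumptions, not code from this patch.

# Usage sketch (illustrative, not part of the patch).
from pilot.server.vectordb_qa import KnownLedgeBaseQA


def generate_answer(query: str):
    """Yield the streamed answer grounded on the top-3 most similar chunks."""
    kb_qa = KnownLedgeBaseQA()                         # loads the vector store and a VicunaLLM
    grounded_prompt = kb_qa.get_similar_answer(query)  # conv_qa_prompt_template filled with retrieved context

    # VicunaLLM._call is a generator in this patch: it posts to the model
    # server's generate_stream endpoint and yields the answer-so-far.
    yield from kb_qa.llm._call(
        prompt=grounded_prompt,
        temperature=0.7,        # assumed sampling settings, mirroring embdserver.py
        max_new_tokens=1024,
        stop=["###"],
    )


if __name__ == "__main__":
    answer = ""
    for partial in generate_answer("How do I use Chroma with LangChain?"):
        answer = partial        # each yielded chunk is the full answer so far
    print(answer)

Because _call now yields partial outputs instead of returning a string, it departs from the usual LangChain LLM._call contract; callers must iterate over it rather than treat the result as plain text.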