fetch top3 similar answer

csunny 2023-05-07 17:32:10 +08:00
parent e4899ff7dd
commit 56e9cde86e
6 changed files with 61 additions and 52 deletions

View File

@@ -22,7 +22,7 @@ LLM_MODEL_CONFIG = {
 }

-VECTOR_SEARCH_TOP_K = 5
+VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
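
VECTOR_SEARCH_TOP_K is the number of nearest-neighbour chunks pulled from the vector store; later in this commit it is passed to the retriever as search_kwargs={"k": VECTOR_SEARCH_TOP_K}, so each query now fetches at most three similar chunks instead of five. A minimal sketch of the effect, using a toy Chroma store in place of the one KnownLedge2Vector builds (the texts and embedding model here are illustrative only):

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

VECTOR_SEARCH_TOP_K = 3  # was 5 before this commit

# Toy store; the project builds its own via KnownLedge2Vector.init_vector_store().
store = Chroma.from_texts(
    ["Chroma installation notes", "Chroma wrapper reference",
     "Unrelated text A", "Unrelated text B", "Unrelated text C"],
    embedding=HuggingFaceEmbeddings(),
)

retriever = store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
docs = retriever.get_relevant_documents("How do I set up Chroma?")
print(len(docs))  # at most 3 documents come back now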

View File

@@ -147,7 +147,7 @@ conv_vicuna_v1 = Conversation(
 )

-conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
 已知内容:
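
In English the renamed template reads roughly: "Based on the known information below, answer the user's question professionally and in detail. If the answer cannot be found in the provided content, say 'the content in the knowledge base is not sufficient to answer this question', but you may offer some suggestions related to the question. Known content: ...". A minimal sketch of filling it, mirroring how the rest of this commit uses it (the example context string is illustrative):

from langchain.prompts import PromptTemplate

from pilot.conversation import conv_qa_prompt_template

# The template is assumed to expose {context} and {question} placeholders,
# matching the input_variables used elsewhere in this commit.
prompt = PromptTemplate(
    template=conv_qa_prompt_template,
    input_variables=["context", "question"],
)
print(prompt.format(
    context="Chroma is a vector store that LangChain can use as a retriever.",
    question="How do I use Chroma with LangChain?",
))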

View File

@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
 from pilot.configs.model_config import *

-class VicunaRequestLLM(LLM):
+class VicunaLLM(LLM):

-    vicuna_generate_path = "generate"
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        if isinstance(stop, list):
-            stop = stop + ["Observation:"]
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+    vicuna_generate_path = "generate_stream"
+    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
         params = {
             "prompt": prompt,
-            "temperature": 0.7,
-            "max_new_tokens": 1024,
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
             "stop": stop
         }
         response = requests.post(
             url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
-        response.raise_for_status()
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        #     if chunk:
-        #         data = json.loads(chunk.decode())
-        #         if data["error_code"] == 0:
-        #             output = data["text"][skip_echo_len:].strip()
-        #             output = self.post_process_code(output)
-        #             yield output
-        return response.json()["response"]
+        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    yield output

     @property
     def _llm_type(self) -> str:
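
VicunaLLM._call now posts to the worker's generate_stream endpoint and yields partial completions instead of returning a single JSON payload, so callers have to iterate over it (yielding from _call turns the method into a generator, which differs from the plain string LangChain's LLM base class normally expects). A minimal consumption sketch, assuming VICUNA_MODEL_SERVER points at a running vicuna worker and that each streamed chunk carries the full text generated so far, as FastChat-style workers do:

from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()

# Each yielded value is the (echo-stripped) text generated so far,
# so the last value seen is the final answer.
answer = ""
for partial in llm._call(
    prompt="What is a vector database?",
    temperature=0.7,
    max_new_tokens=512,
    stop=["###"],
):
    answer = partial

print(answer)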

View File

@@ -4,29 +4,44 @@
 import requests
 import json
 import time
-import uuid
 from urllib.parse import urljoin

 import gradio as gr

 from pilot.configs.model_config import *
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate

-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+vicuna_stream_path = "generate_stream"

-def generate(prompt):
+def generate(query):
+
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+    prompt = state.get_prompt()
+
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }

-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )

     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
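
The skip_echo_len expression is an inlined version of the form the old VicunaRequestLLM used, len(prompt.replace("</s>", " ")) + 1: replacing each 4-character "</s>" with a single space shortens the string by 3 characters per occurrence, and since the worker echoes the prompt back at the start of "text", the client slices that many characters off each streamed chunk. A quick check of the equivalence:

prompt = "A chat between a user and an assistant.</s>USER: hi</s>ASSISTANT:"
old_form = len(prompt.replace("</s>", " ")) + 1
new_form = len(prompt) + 1 - prompt.count("</s>") * 3
assert old_form == new_form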
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + ""
                 yield(output)
-                time.sleep(0.02)

 if __name__ == "__main__":
     print(LLM_MODEL)
     with gr.Blocks() as demo:
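
The trailing context lines show generate() being used inside a gr.Blocks demo. A minimal wiring sketch, assuming it lives in the same module as generate(); the component names and layout here are illustrative, not the actual demo from the file:

import gradio as gr

with gr.Blocks() as demo:
    question = gr.Textbox(label="question")
    answer = gr.Textbox(label="answer")
    ask = gr.Button("Ask")
    # generate() is a generator, so with queueing enabled gradio streams
    # each yielded value into the output textbox.
    ask.click(fn=generate, inputs=question, outputs=answer)

demo.queue().launch()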

View File

@@ -3,31 +3,28 @@
 from pilot.vector_store.file_loader import KnownLedge2Vector
 from langchain.prompts import PromptTemplate
-from pilot.conversation import conv_qk_prompt_template
+from pilot.conversation import conv_qa_prompt_template
 from langchain.chains import RetrievalQA
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
-from pilot.model.vicuna_llm import VicunaLLM

 class KnownLedgeBaseQA:

-    llm: object = None
     def __init__(self) -> None:
         k2v = KnownLedge2Vector()
         self.vector_store = k2v.init_vector_store()
-        self.llm = VicunaLLM()

-    def get_answer(self, query):
-        prompt_template = conv_qk_prompt_template
+    def get_similar_answer(self, query):
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=conv_qa_prompt_template,
             input_variables=["context", "question"]
         )

-        knownledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
-            prompt=prompt
-        )
-        knownledge_chain.return_source_documents = True
-        result = knownledge_chain({"query": query})
-        yield result
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
+        docs = retriever.get_relevant_documents(query=query)
+
+        context = [d.page_content for d in docs]
+        result = prompt.format(context="\n".join(context), question=query)
+        return result
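
With this change KnownLedgeBaseQA no longer drives a RetrievalQA chain itself; get_similar_answer only retrieves the top-3 most similar chunks and returns the filled conv_qa_prompt_template, leaving generation to the caller. A minimal sketch of the intended call pattern; the import path for KnownLedgeBaseQA is not shown in this view, so the import is omitted:

kb_qa = KnownLedgeBaseQA()  # class from the file above

query = "How do I use Chroma inside LangChain?"
grounded_prompt = kb_qa.get_similar_answer(query)

# The caller is expected to send `grounded_prompt` to the vicuna worker,
# e.g. by passing it through generate() or VicunaLLM._call shown above.
print(grounded_prompt)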

View File

@@ -170,7 +170,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     query = state.messages[-2][1]

-    # Add a context hint to the prompt
+    # Add a context hint to the prompt for knowledge-grounded dialogue. Should the hint go only in the first turn, or be added on every turn?
+    # If the user's questions range widely, it should be added on every turn.
     if db_selector:
         new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
         new_state.append_message(new_state.roles[1], None)
@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
+                time.sleep(0.02)
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

+    # Record this run in the log
     finish_tstamp = time.time()
     logger.info(f"{output}")
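
One way to act on the question raised in the comment above (first turn only vs. every turn) would be to mirror the existing db_selector branch and prepend the knowledge-base prompt on each turn. This is only a sketch of that idea, not code from this commit; kb_selector and kb_qa are hypothetical names:

# Hypothetical variant inside http_bot: ground every turn on retrieved context.
# `kb_selector` and `kb_qa` do not exist in this commit.
if kb_selector:
    grounded_query = kb_qa.get_similar_answer(query)
    new_state.append_message(new_state.roles[0], grounded_query)
else:
    new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)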