mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-30 15:21:02 +00:00)

fetch top3 similar answer

commit 56e9cde86e (parent e4899ff7dd)

@@ -22,7 +22,7 @@ LLM_MODEL_CONFIG = {
 }
 
 
-VECTOR_SEARCH_TOP_K = 5
+VECTOR_SEARCH_TOP_K = 3
 LLM_MODEL = "vicuna-13b"
 LIMIT_MODEL_CONCURRENCY = 5
 MAX_POSITION_EMBEDDINGS = 2048
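
This change drops the knowledge-base retrieval depth from the top 5 to the top 3 most similar chunks. A minimal sketch of what the constant controls, reusing the retriever calls that appear in the knowledge-base QA diff further down (a vector store already populated through KnownLedge2Vector is assumed):

from pilot.configs.model_config import VECTOR_SEARCH_TOP_K  # now 3
from pilot.vector_store.file_loader import KnownLedge2Vector

# init_vector_store() is the loader used elsewhere in this commit; it must already
# contain embedded documents for the query below to return anything.
vector_store = KnownLedge2Vector().init_vector_store()
retriever = vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
docs = retriever.get_relevant_documents("how do I point DB-GPT at the vicuna model server?")
assert len(docs) <= VECTOR_SEARCH_TOP_K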
@@ -147,7 +147,7 @@ conv_vicuna_v1 = Conversation(
 )
 
 
-conv_qk_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
+conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题。
 如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议:
 
 已知内容:
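
The question-answering template is renamed from conv_qk_prompt_template to conv_qa_prompt_template (presumably fixing the qk typo); its text is unchanged: answer from the supplied known content, say the knowledge base is insufficient when no answer can be found, and offer related suggestions. A small sketch of filling its context and question slots, mirroring what the rewritten generate() further down does:

from langchain.prompts import PromptTemplate
from pilot.conversation import conv_qa_prompt_template  # renamed in this commit

prompt = PromptTemplate(
    template=conv_qa_prompt_template,
    input_variables=["context", "question"],
)
# The context string is only illustrative; in the real flow it comes from the vector store.
print(prompt.format(context="DB-GPT keeps document embeddings in a local vector store.",
                    question="Where are the document embeddings stored?"))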
@@ -10,33 +10,29 @@ from typing import Any, Mapping, Optional, List
 from langchain.llms.base import LLM
 from pilot.configs.model_config import *
 
-class VicunaRequestLLM(LLM):
+class VicunaLLM(LLM):
 
-    vicuna_generate_path = "generate"
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        if isinstance(stop, list):
-            stop = stop + ["Observation:"]
-
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
+    vicuna_generate_path = "generate_stream"
+    def _call(self, prompt: str, temperature: float, max_new_tokens: int, stop: Optional[List[str]] = None) -> str:
         params = {
             "prompt": prompt,
-            "temperature": 0.7,
-            "max_new_tokens": 1024,
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
             "stop": stop
         }
         response = requests.post(
             url=urljoin(VICUNA_MODEL_SERVER, self.vicuna_generate_path),
             data=json.dumps(params),
         )
-        response.raise_for_status()
-        # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        #     if chunk:
-        #         data = json.loads(chunk.decode())
-        #         if data["error_code"] == 0:
-        #             output = data["text"][skip_echo_len:].strip()
-        #             output = self.post_process_code(output)
-        #             yield output
-        return response.json()["response"]
+
+        skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
+        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                data = json.loads(chunk.decode())
+                if data["error_code"] == 0:
+                    output = data["text"][skip_echo_len:].strip()
+                    yield output
 
     @property
     def _llm_type(self) -> str:
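
After this change the HTTP client class is VicunaLLM: it posts to the worker's generate_stream endpoint, takes temperature and max_new_tokens from the caller instead of hardcoding 0.7 and 1024, and _call becomes a generator that yields the progressively longer answer with the echoed prompt stripped. A minimal sketch of consuming it (a vicuna worker reachable at VICUNA_MODEL_SERVER is assumed):

from pilot.model.vicuna_llm import VicunaLLM

llm = VicunaLLM()
# Each yielded string is the accumulated answer so far, echo removed.
for partial in llm._call("What is a vector store?", temperature=0.7, max_new_tokens=256):
    print(partial)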
@@ -4,29 +4,44 @@
 import requests
 import json
 import time
 import uuid
 from urllib.parse import urljoin
 import gradio as gr
+from pilot.configs.model_config import *
+from pilot.conversation import conv_qa_prompt_template, conv_templates
+from langchain.prompts import PromptTemplate
 
-vicuna_base_uri = "http://192.168.31.114:21002/"
-vicuna_stream_path = "worker_generate_stream"
-vicuna_status_path = "worker_get_status"
+vicuna_stream_path = "generate_stream"
 
-def generate(prompt):
+def generate(query):
+
+    template_name = "conv_one_shot"
+    state = conv_templates[template_name].copy()
+
+    pt = PromptTemplate(
+        template=conv_qa_prompt_template,
+        input_variables=["context", "question"]
+    )
+
+    result = pt.format(context="This page covers how to use the Chroma ecosystem within LangChain. It is broken into two parts: installation and setup, and then references to specific Chroma wrappers.",
+                       question=query)
+
+    print(result)
+
+    state.append_message(state.roles[0], result)
+    state.append_message(state.roles[1], None)
+
+    prompt = state.get_prompt()
     params = {
         "model": "vicuna-13b",
         "prompt": prompt,
         "temperature": 0.7,
-        "max_new_tokens": 512,
+        "max_new_tokens": 1024,
         "stop": "###"
     }
 
-    sts_response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_status_path)
-    )
-    print(sts_response.text)
-
     response = requests.post(
-        url=urljoin(vicuna_base_uri, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(VICUNA_MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
     )
 
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
@@ -34,11 +49,10 @@ def generate(prompt):
         if chunk:
             data = json.loads(chunk.decode())
             if data["error_code"] == 0:
-                output = data["text"]
+                output = data["text"][skip_echo_len:].strip()
+                state.messages[-1][-1] = output + "▌"
                 yield(output)
 
             time.sleep(0.02)
 
 
 if __name__ == "__main__":
+    print(LLM_MODEL)
     with gr.Blocks() as demo:
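
The rewritten generate() builds a conv_one_shot conversation whose first user turn is the filled QA template (with a hardcoded Chroma description standing in for retrieved context), posts it to the generate_stream endpoint behind VICUNA_MODEL_SERVER, and streams the de-echoed output while updating the conversation state. A hedged sketch of how such a generator could be wired into the gr.Blocks demo in __main__ (the component names are assumptions; gradio streams generator output when queuing is enabled):

import gradio as gr

# Hypothetical wiring; generate(query) is the streaming function defined above.
with gr.Blocks() as demo:
    question = gr.Textbox(label="question")
    answer = gr.Textbox(label="answer")
    question.submit(generate, inputs=question, outputs=answer)

demo.queue().launch()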
@@ -3,31 +3,28 @@
 
 from pilot.vector_store.file_loader import KnownLedge2Vector
 from langchain.prompts import PromptTemplate
-from pilot.conversation import conv_qk_prompt_template
-from langchain.chains import RetrievalQA
+from pilot.conversation import conv_qa_prompt_template
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.model.vicuna_llm import VicunaLLM
 
 class KnownLedgeBaseQA:
 
     llm: object = None
 
     def __init__(self) -> None:
         k2v = KnownLedge2Vector()
         self.vector_store = k2v.init_vector_store()
         self.llm = VicunaLLM()
 
-    def get_answer(self, query):
-        prompt_template = conv_qk_prompt_template
+    def get_similar_answer(self, query):
 
         prompt = PromptTemplate(
-            template=prompt_template,
+            template=conv_qa_prompt_template,
             input_variables=["context", "question"]
         )
 
-        knownledge_chain = RetrievalQA.from_llm(
-            llm=self.llm,
-            retriever=self.vector_store.as_retriever(search_kwargs={"k", VECTOR_SEARCH_TOP_K}),
-            prompt=prompt
-        )
-        knownledge_chain.return_source_documents = True
-        result = knownledge_chain({"query": query})
-        yield result
+        retriever = self.vector_store.as_retriever(search_kwargs={"k": VECTOR_SEARCH_TOP_K})
+        docs = retriever.get_relevant_documents(query=query)
+
+        context = [d.page_content for d in docs]
+        result = prompt.format(context="\n".join(context), question=query)
+        return result
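
Where the old get_answer() ran a RetrievalQA chain and yielded its result, the new get_similar_answer() stops before generation: it fetches the VECTOR_SEARCH_TOP_K (now 3) most similar chunks, joins their page_content with newlines, and returns the filled QA prompt, leaving the model call to the caller. A sketch of using it (the import path of KnownLedgeBaseQA is an assumption; adjust it to wherever the class lives):

# Hypothetical import path for the class shown in the diff above.
from pilot.server.knownledge_base_qa import KnownLedgeBaseQA

kbqa = KnownLedgeBaseQA()
# Returns conv_qa_prompt_template filled with the three most similar chunks and
# the original question; no model request is made here.
grounded_prompt = kbqa.get_similar_answer("How does DB-GPT load documents into the vector store?")
print(grounded_prompt)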
@@ -170,7 +170,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
 
     query = state.messages[-2][1]
 
-    # prompt 中添加上下文提示
+    # prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
+    # 如果用户侧的问题跨度很大, 应该每一轮都加提示。
     if db_selector:
         new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
         new_state.append_message(new_state.roles[1], None)
@@ -179,7 +180,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
         new_state.append_message(new_state.roles[0], query)
         new_state.append_message(new_state.roles[1], None)
     state = new_state
 
 
     # try:
     #     if not db_selector:
     #         sim_q = get_simlar(query)
@@ -222,7 +223,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
                     state.messages[-1][-1] = output
                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                     return
                 time.sleep(0.02)
 
     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
@@ -231,6 +232,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
     state.messages[-1][-1] = state.messages[-1][-1][:-1]
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 
+    # 记录运行日志
     finish_tstamp = time.time()
     logger.info(f"{output}")
 
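
The new comments in http_bot ask where the knowledge-base context should be injected when chatting over existing knowledge: only in the first round, or in every round; the note argues for every round when the user's questions jump between topics. The commented-out try block hints at fetching similar content whenever no database is selected. A hedged sketch of what that wiring could look like, reusing the hypothetical kbqa instance from the previous sketch:

# Hypothetical continuation of http_bot: ground the user turn in the knowledge base
# on every round whenever no database is selected.
if db_selector:
    new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
else:
    # kbqa would be a KnownLedgeBaseQA instance created at module import time.
    new_state.append_message(new_state.roles[0], kbqa.get_similar_answer(query))
new_state.append_message(new_state.roles[1], None)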