Multi-scenario conversation architecture, phase 1

This commit is contained in:
yhjun1026
2023-05-24 17:33:40 +08:00
parent b5139ea643
commit 2fc62c16ef
37 changed files with 2052 additions and 131 deletions


@@ -20,13 +20,14 @@ from pilot.connections.mysql import MySQLOperator
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
from pilot.configs.model_config import LOGDIR, DATASETS_DIR
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.scene.base_chat import BaseChat
from pilot.commands.exception_not_commands import NotCommands
@@ -47,6 +48,8 @@ from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -68,7 +71,8 @@ priority = {
}
# Load plugins
CFG = Config()
CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
    "user": CFG.LOCAL_DB_USER,
@@ -125,6 +129,10 @@ def load_demo(url_params, request: gr.Request):
        gr.Dropdown.update(choices=dbs)
    state = default_conversation.copy()

    unique_id = uuid.uuid1()
    state.conv_id = str(unique_id)

    return (state,
            dropdown_update,
            gr.Chatbot.update(visible=True),
@@ -164,6 +172,8 @@ def add_text(state, text, request: gr.Request):
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    ### TODO: keep the raw user input around so the new chat-scene pipeline can use it (it becomes user_input for ChatWithDb below)
    state.last_user_input = text
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
@@ -178,42 +188,52 @@ def post_process_code(code):
return code
def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
    if mode == conversation_types["default_knownledge"] and not db_selector:
        return ChatScene.ChatKnowledge
    elif mode == conversation_types["custome"] and not db_selector:
        return ChatScene.ChatNewKnowledge
    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
        return ChatScene.ChatWithDb
    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
        return ChatScene.ChatExecution
    else:
        return ChatScene.ChatNormal
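
# A rough illustration of the routing above (the db name is hypothetical; the
# keys come from the conversation_types / conversation_sql_mode dicts defined
# earlier in this file):
#   get_chat_mode(conversation_types["default_knownledge"], None, None)  -> ChatScene.ChatKnowledge
#   get_chat_mode(conversation_types["custome"], None, None)             -> ChatScene.ChatNewKnowledge
#   get_chat_mode(None, conversation_sql_mode["auto_execute_ai_response"], "my_test_db") -> ChatScene.ChatWithDb
#   get_chat_mode(None, None, None)                                      -> ChatScene.ChatNormal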
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
    logger.info(f"User message sent! {state.conv_id},{sql_mode},{db_selector}")
    start_tstamp = time.time()
    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
    print(f"Current chat mode: {scene.value}")
    model_name = CFG.LLM_MODEL
    if ChatScene.ChatWithDb == scene:
        logger.info("DB-based chat goes through the new scene pipeline")
        chat_param = {
            "chat_session_id": state.conv_id,
            "db_name": db_selector,
            "user_input": state.last_user_input,
        }
        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
        chat.call()
        # state.append_message(state.roles[1], chat.current_ai_response())
        state.messages[-1][-1] = f"{chat.current_ai_response()}"
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    else:
        cfg = Config()
        auto_prompt = AutoModePrompt()
        auto_prompt.command_registry = cfg.command_registry

        dbname = db_selector
        # TODO: the request here should weave in the existing knowledge base so the model answers from it, so the prompt still needs refinement
        if state.skip_next:
            # This generate call is skipped due to invalid inputs
            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            return
        # TODO when tab mode is AUTO_GPT, the prompt needs to be rebuilt.
        if len(state.messages) == state.offset + 2:
            query = state.messages[-2][1]
            # The first round of the conversation needs the guiding prompt added
            if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
                # The first round in autogpt mode needs its own dedicated prompt built
                system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
                                                                   db_schemes=gen_sqlgen_conversation(dbname))
                logger.info("[TEST]:" + system_prompt)
                template_name = "auto_dbgpt_one_shot"
                new_state = conv_templates[template_name].copy()
                new_state.append_message(role='USER', message=system_prompt)
                # new_state.append_message(new_state.roles[0], query)
                new_state.append_message(new_state.roles[1], None)
            else:
                template_name = "conv_one_shot"
                new_state = conv_templates[template_name].copy()
            # Add a context hint to the prompt so the model answers from existing knowledge. Open question: should the context hint go only in the first round, or be added on every round?
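
The ChatWithDb branch above hands the whole turn to a scene implementation resolved by CHAT_FACTORY. The factory itself is not shown in this diff; as a mental model only, a minimal sketch (assuming a registry keyed by the scene's value string, whose entries are BaseChat subclasses accepting the chat_param keywords) could look like:

# Sketch only -- the real ChatFactory lives in pilot/scene/chat_factory.py and
# its actual registration mechanism is not shown in this commit.
class SketchChatFactory:
    def __init__(self):
        self._registry = {}  # scene value (str) -> BaseChat subclass

    def register(self, scene_value, chat_cls):
        self._registry[scene_value] = chat_cls

    def get_implementation(self, scene_value, **kwargs):
        # Resolve the scene's chat class and build it with the per-turn
        # parameters (chat_session_id, db_name, user_input, ...).
        chat_cls = self._registry.get(scene_value)
        if chat_cls is None:
            raise ValueError(f"No chat implementation registered for scene: {scene_value}")
        return chat_cls(**kwargs)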
@@ -225,99 +245,47 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
            new_state.append_message(new_state.roles[0], query)
            new_state.append_message(new_state.roles[1], None)
            new_state.conv_id = uuid.uuid4().hex
            state = new_state
        else:
            ### Follow-up turns of the conversation
            query = state.messages[-2][1]
            if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
                ## Fetch the return value of the last plugin call
                follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
                state.messages[0][0] = ""
                state.messages[0][1] = ""
                state.messages[-2][1] = follow_up_prompt

        prompt = state.get_prompt()
        skip_echo_len = len(prompt.replace("</s>", " ")) + 1

        if mode == conversation_types["default_knownledge"] and not db_selector:
            query = state.messages[-2][1]
            knqa = KnownLedgeBaseQA()
            state.messages[-2][1] = knqa.get_similar_answer(query)
            prompt = state.get_prompt()
            state.messages[-2][1] = query
            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
print("向量数据库持久化地址: ", persist_dir)
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, 1)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
print("向量数据库持久化地址: ", persist_dir)
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, 1)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
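
In the "custome" knowledge branch above, the retrieved chunks are folded into the user question via conv_qa_prompt_template, whose exact text is defined elsewhere in the repo. A minimal standalone sketch of the same pattern (assumed template wording, hypothetical retrieved context):

from langchain.prompts import PromptTemplate

# Assumed wording -- the real conv_qa_prompt_template is defined elsewhere.
sketch_qa_template = """Answer the question based on the known information below.
Known information:
{context}

Question:
{question}"""

prompt_template = PromptTemplate(
    template=sketch_qa_template,
    input_variables=["context", "question"],
)
print(prompt_template.format(
    context="DB-GPT is a local GPT experiment built around databases.",  # hypothetical chunk
    question="What is DB-GPT?",
))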
        # Make requests
        payload = {
            "model": model_name,
            "prompt": prompt,
            "temperature": float(temperature),
            "max_new_tokens": int(max_new_tokens),
            "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
        }
        logger.info(f"Request: \n{payload}")
        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            response = requests.post(urljoin(CFG.MODEL_SERVER, "generate"),
                                     headers=headers, json=payload, timeout=120)
            print(response.json())
            print(str(response))
            try:
                text = response.text.strip()
                text = text.rstrip()
                respObj = json.loads(text)
                xx = respObj['response']
                xx = xx.strip(b'\x00'.decode())
                respObj_ex = json.loads(xx)
                if respObj_ex['error_code'] == 0:
                    ai_response = None
                    all_text = respObj_ex['text']
                    ### Parse the returned text to extract the AI reply
                    tmpResp = all_text.split(state.sep)
                    last_index = -1
                    for i in range(len(tmpResp)):
                        if tmpResp[i].find('ASSISTANT:') != -1:
                            last_index = i
                    if last_index != -1:
                        # Only take a reply when an ASSISTANT turn was actually found
                        ai_response = tmpResp[last_index]
                        ai_response = ai_response.replace("ASSISTANT:", "")
                        ai_response = ai_response.replace("\n", "")
                        ai_response = ai_response.replace("\_", "_")
                        print(ai_response)
                    if ai_response is None:
                        state.messages[-1][-1] = "ASSISTANT failed to reply correctly. The full output was:\n" + all_text
                        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
                    else:
                        plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
                        cfg.set_last_plugin_return(plugin_resp)
                        print(plugin_resp)
                        state.messages[-1][-1] = "Model inference info:\n" + ai_response + "\n\nDB-GPT execution result:\n" + plugin_resp
                        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
            except NotCommands as e:
                print("Command execution: " + e.message)
                state.messages[-1][-1] = "Command execution: " + e.message + "\nModel output:\n" + str(ai_response)
                yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        else:
            # Streaming output
            state.messages[-1][-1] = ""
            yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
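
Both branches ultimately talk to the model server's generate endpoint with the payload built earlier. The streaming consumption itself is outside this hunk; a minimal sketch of how such a payload is typically sent and streamed (assuming NUL-delimited JSON chunks, as hinted by the b'\x00' handling above -- the chunk field names are assumptions):

import json
from urllib.parse import urljoin

import requests

def stream_generate(model_server, payload, headers):
    # Sketch only: POST the payload and yield incremental "text" fields from
    # NUL-delimited JSON chunks, mirroring the delimiter handling above.
    response = requests.post(urljoin(model_server, "generate"),
                             headers=headers, json=payload,
                             stream=True, timeout=120)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data.get("error_code", 0) == 0:
            yield data.get("text", "")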
@@ -413,6 +381,7 @@ def build_single_model_ui():
"""
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion("参数", open=False, visible=False) as parameter_row:
@@ -472,7 +441,7 @@ def build_single_model_ui():
folder_files = gr.File(label="Add a folder",
                       accept_multiple_files=True,
                       file_count="directory",
                       show_label=False)
load_folder_button = gr.Button("Upload and load into the knowledge base")
with gr.Blocks():
@@ -516,8 +485,8 @@ def build_single_model_ui():
[state, chatbot] + btn_list
)
vs_add.click(fn=save_vs_name, show_progress=True,
             inputs=[vs_name],
             outputs=[vs_name])
load_file_button.click(fn=knowledge_embedding_store,
                       show_progress=True,
                       inputs=[vs_name, files],
@@ -569,6 +538,7 @@ def save_vs_name(vs_name):
vector_store_name["vs_name"] = vs_name
return vs_name
def knowledge_embedding_store(vs_id, files):
    # vs_path = os.path.join(VS_ROOT_PATH, vs_id)
    if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)):
@@ -584,10 +554,10 @@ def knowledge_embedding_store(vs_id, files):
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
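
A hypothetical call for the helper above (the store name and file list are invented; the returned path is the on-disk vector store that similar_search later reads):

# vs_path = knowledge_embedding_store("default", uploaded_files)  # uploaded_files from the gr.File component
# print(vs_path)  # e.g. <KNOWLEDGE_UPLOAD_ROOT_PATH>/default/default.vectordb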
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -603,7 +573,7 @@ if __name__ == "__main__":
# Initialize configuration
cfg = Config()
dbs = cfg.local_db.get_database_list()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))