插件输出优化

This commit is contained in:
yhjun1026 2023-05-15 14:18:08 +08:00
parent d5b5fc4f9a
commit e2750fcea0
2 changed files with 45 additions and 41 deletions

View File

@ -69,7 +69,7 @@ def execute_ai_response_json(
arguments, arguments,
prompt, prompt,
) )
result = f"Command {command_name} returned: " f"{command_result}" result = f"{command_result}"
return result return result

View File

@ -20,7 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.plugins import scan_plugins from pilot.plugins import scan_plugins
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.auto_mode_prompt import AutoModePrompt from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator from pilot.prompts.generator import PromptGenerator
from pilot.commands.exception_not_commands import NotCommands from pilot.commands.exception_not_commands import NotCommands
@ -60,14 +60,15 @@ priority = {
"vicuna-13b": "aaa" "vicuna-13b": "aaa"
} }
def get_simlar(q): def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md")) docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1) docs = docsearch.similarity_search_with_score(q, k=1)
contents = [dc.page_content for dc, _ in docs] contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents) return "\n".join(contents)
def gen_sqlgen_conversation(dbname): def gen_sqlgen_conversation(dbname):
mo = MySQLOperator( mo = MySQLOperator(
**DB_SETTINGS **DB_SETTINGS
@ -80,10 +81,12 @@ def gen_sqlgen_conversation(dbname):
message += s["schema_info"] + ";" message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n" return f"数据库{dbname}的Schema信息如下: {message}\n"
def get_database_list(): def get_database_list():
mo = MySQLOperator(**DB_SETTINGS) mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list() return mo.get_db_list()
get_window_url_params = """ get_window_url_params = """
function() { function() {
const params = new URLSearchParams(window.location.search); const params = new URLSearchParams(window.location.search);
@ -96,6 +99,8 @@ function() {
return url_params; return url_params;
} }
""" """
def load_demo(url_params, request: gr.Request): def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
@ -113,6 +118,7 @@ def load_demo(url_params, request: gr.Request):
gr.Row.update(visible=True), gr.Row.update(visible=True),
gr.Accordion.update(visible=True)) gr.Accordion.update(visible=True))
def get_conv_log_filename(): def get_conv_log_filename():
t = datetime.datetime.now() t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@ -125,9 +131,8 @@ def regenerate(state, request: gr.Request):
state.skip_next = False state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}") logger.info(f"clear_history. ip: {request.client.host}")
state = None state = None
return (state, [], "") + (disable_btn,) * 5 return (state, [], "") + (disable_btn,) * 5
@ -140,12 +145,13 @@ def add_text(state, text, request: gr.Request):
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5 return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
""" Default support 4000 tokens, if tokens too lang, we will cut off """ """ Default support 4000 tokens, if tokens too lang, we will cut off """
text = text[:4000] text = text[:4000]
state.append_message(state.roles[0], text) state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None) state.append_message(state.roles[1], None)
state.skip_next = False state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code): def post_process_code(code):
sep = "\n```" sep = "\n```"
if sep in code: if sep in code:
@ -156,6 +162,7 @@ def post_process_code(code):
code = sep.join(blocks) code = sep.join(blocks)
return code return code
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request): def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
print("AUTO DB-GPT模式.") print("AUTO DB-GPT模式.")
@ -183,7 +190,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# 第一轮对话需要加入提示Prompt # 第一轮对话需要加入提示Prompt
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
# autogpt模式的第一轮对话需要 构建专属prompt # autogpt模式的第一轮对话需要 构建专属prompt
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname)) system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
db_schemes=gen_sqlgen_conversation(dbname))
logger.info("[TEST]:" + system_prompt) logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot" template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy() new_state = conv_templates[template_name].copy()
@ -210,20 +218,18 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# 第一轮对话需要加入提示Prompt # 第一轮对话需要加入提示Prompt
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]: if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
## 获取最后一次插件的返回 ## 获取最后一次插件的返回
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query]) follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
state.messages[0][0] = "" state.messages[0][0] = ""
state.messages[0][1] = "" state.messages[0][1] = ""
state.messages[-2][1] = follow_up_prompt state.messages[-2][1] = follow_up_prompt
if mode == conversation_types["default_knownledge"] and not db_selector: if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1] query = state.messages[-2][1]
knqa = KnownLedgeBaseQA() knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query) state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt() prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1 skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests # Make requests
@ -243,11 +249,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
print(response.json()) print(response.json())
print(str(response)) print(str(response))
try: try:
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
# response = response.replace("\n", "\\n")
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
# response = response.replace("\n", "\\)
text = response.text.strip() text = response.text.strip()
text = text.rstrip() text = text.rstrip()
respObj = json.loads(text) respObj = json.loads(text)
@ -257,7 +258,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
respObj_ex = json.loads(xx) respObj_ex = json.loads(xx)
if respObj_ex['error_code'] == 0: if respObj_ex['error_code'] == 0:
ai_response = None ai_response = None
all_text = respObj_ex['text'] all_text = respObj_ex['text']
### 解析返回文本获取AI回复部分 ### 解析返回文本获取AI回复部分
tmpResp = all_text.split(state.sep) tmpResp = all_text.split(state.sep)
last_index = -1 last_index = -1
@ -277,11 +278,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response) plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
cfg.set_last_plugin_return(plugin_resp) cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp) print(plugin_resp)
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e: except NotCommands as e:
print("命令执行:" + e.message) print("命令执行:" + e.message)
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response) state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else: else:
# 流式输出 # 流式输出
@ -304,7 +305,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
output = data["text"] + f" (error_code: {data['error_code']})" output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + ( yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return return
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
@ -333,8 +334,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
block_css = ( block_css = (
code_highlight_css code_highlight_css
+ """ + """
pre { pre {
white-space: pre-wrap; /* Since CSS 2.1 */ white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */ white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@ -348,23 +349,26 @@ block_css = (
""" """
) )
def change_sql_mode(sql_mode): def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]: if sql_mode in ["直接执行结果"]:
return gr.update(visible=True) return gr.update(visible=True)
else: else:
return gr.update(visible=False) return gr.update(visible=False)
def change_mode(mode): def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]: if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False) return gr.update(visible=False)
else: else:
return gr.update(visible=True) return gr.update(visible=True)
def change_tab(): def change_tab():
autogpt = True autogpt = True
def build_single_model_ui(): def build_single_model_ui():
notice_markdown = """ notice_markdown = """
# DB-GPT # DB-GPT
@ -396,7 +400,7 @@ def build_single_model_ui():
interactive=True, interactive=True,
label="最大输出Token数", label="最大输出Token数",
) )
tabs= gr.Tabs() tabs = gr.Tabs()
with tabs: with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL") tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
with tab_sql: with tab_sql:
@ -427,7 +431,7 @@ def build_single_model_ui():
with gr.Column() as doc2vec: with gr.Column() as doc2vec:
gr.Markdown("向知识库中添加文件") gr.Markdown("向知识库中添加文件")
with gr.Tab("上传文件"): with gr.Tab("上传文件"):
files = gr.File(label="添加文件", files = gr.File(label="添加文件",
file_types=[".txt", ".md", ".docx", ".pdf"], file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple", file_count="multiple",
show_label=False show_label=False
@ -436,11 +440,10 @@ def build_single_model_ui():
load_file_button = gr.Button("上传并加载到知识库") load_file_button = gr.Button("上传并加载到知识库")
with gr.Tab("上传文件夹"): with gr.Tab("上传文件夹"):
folder_files = gr.File(label="添加文件", folder_files = gr.File(label="添加文件",
file_count="directory", file_count="directory",
show_label=False) show_label=False)
load_folder_button = gr.Button("上传并加载到知识库") load_folder_button = gr.Button("上传并加载到知识库")
with gr.Blocks(): with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550) chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row(): with gr.Row():
@ -449,9 +452,9 @@ def build_single_model_ui():
show_label=False, show_label=False,
placeholder="Enter text and press ENTER", placeholder="Enter text and press ENTER",
visible=False, visible=False,
).style(container=False) ).style(container=False)
with gr.Column(scale=2, min_width=50): with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value="发送", visible=False) send_btn = gr.Button(value="发送", visible=False)
with gr.Row(visible=False) as button_row: with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value="重新生成", interactive=False) regenerate_btn = gr.Button(value="重新生成", interactive=False)
@ -465,7 +468,7 @@ def build_single_model_ui():
[state, chatbot] + btn_list, [state, chatbot] + btn_list,
) )
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
textbox.submit( textbox.submit(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then( ).then(
@ -487,10 +490,10 @@ def build_single_model_ui():
def build_webdemo(): def build_webdemo():
with gr.Blocks( with gr.Blocks(
title="数据库智能助手", title="数据库智能助手",
# theme=gr.themes.Base(), # theme=gr.themes.Base(),
theme=gr.themes.Default(), theme=gr.themes.Default(),
css=block_css, css=block_css,
) as demo: ) as demo:
url_params = gr.JSON(visible=False) url_params = gr.JSON(visible=False)
( (
@ -520,6 +523,7 @@ def build_webdemo():
raise ValueError(f"Unknown model list mode: {args.model_list_mode}") raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo return demo
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--host", type=str, default="0.0.0.0")
@ -553,7 +557,7 @@ if __name__ == "__main__":
for command_category in command_categories: for command_category in command_categories:
command_registry.import_commands(command_category) command_registry.import_commands(command_category)
cfg.command_registry =command_registry cfg.command_registry = command_registry
logger.info(args) logger.info(args)
demo = build_webdemo() demo = build_webdemo()
@ -561,4 +565,4 @@ if __name__ == "__main__":
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch( ).launch(
server_name=args.host, server_port=args.port, share=args.share, max_threads=200, server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
) )