插件输出优化

This commit is contained in:
yhjun1026 2023-05-15 14:18:08 +08:00
parent d5b5fc4f9a
commit e2750fcea0
2 changed files with 45 additions and 41 deletions

View File

@ -69,7 +69,7 @@ def execute_ai_response_json(
arguments,
prompt,
)
result = f"Command {command_name} returned: " f"{command_result}"
result = f"{command_result}"
return result

View File

@ -60,14 +60,15 @@ priority = {
"vicuna-13b": "aaa"
}
def get_simlar(q):
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
@ -80,10 +81,12 @@ def gen_sqlgen_conversation(dbname):
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
def get_database_list():
mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list()
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
@ -96,6 +99,8 @@ function() {
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
@ -113,6 +118,7 @@ def load_demo(url_params, request: gr.Request):
gr.Row.update(visible=True),
gr.Accordion.update(visible=True))
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@ -125,9 +131,8 @@ def regenerate(state, request: gr.Request):
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
@ -146,6 +151,7 @@ def add_text(state, text, request: gr.Request):
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
sep = "\n```"
if sep in code:
@ -156,6 +162,7 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
print("AUTO DB-GPT模式.")
@ -183,7 +190,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# 第一轮对话需要加入提示Prompt
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
# autogpt模式的第一轮对话需要 构建专属prompt
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
db_schemes=gen_sqlgen_conversation(dbname))
logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy()
@ -215,13 +223,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
state.messages[0][1] = ""
state.messages[-2][1] = follow_up_prompt
if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
@ -243,11 +249,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
print(response.json())
print(str(response))
try:
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
# response = response.replace("\n", "\\n")
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
# response = response.replace("\n", "\\)
text = response.text.strip()
text = text.rstrip()
respObj = json.loads(text)
@ -277,11 +278,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp)
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e:
print("命令执行:" + e.message)
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
# 流式输出
@ -348,23 +349,26 @@ block_css = (
"""
)
def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
def change_tab():
autogpt = True
def build_single_model_ui():
def build_single_model_ui():
notice_markdown = """
# DB-GPT
@ -396,7 +400,7 @@ def build_single_model_ui():
interactive=True,
label="最大输出Token数",
)
tabs= gr.Tabs()
tabs = gr.Tabs()
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
with tab_sql:
@ -440,7 +444,6 @@ def build_single_model_ui():
show_label=False)
load_folder_button = gr.Button("上传并加载到知识库")
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
@ -520,6 +523,7 @@ def build_webdemo():
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
@ -553,7 +557,7 @@ if __name__ == "__main__":
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry =command_registry
cfg.command_registry = command_registry
logger.info(args)
demo = build_webdemo()