插件输出优化

This commit is contained in:
yhjun1026 2023-05-15 14:18:08 +08:00
parent d5b5fc4f9a
commit e2750fcea0
2 changed files with 45 additions and 41 deletions

View File

@ -69,7 +69,7 @@ def execute_ai_response_json(
arguments,
prompt,
)
result = f"Command {command_name} returned: " f"{command_result}"
result = f"{command_result}"
return result

View File

@ -20,7 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.commands.exception_not_commands import NotCommands
@ -60,14 +60,15 @@ priority = {
"vicuna-13b": "aaa"
}
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
contents = [dc.page_content for dc, _ in docs]
return "\n".join(contents)
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(
**DB_SETTINGS
@ -80,10 +81,12 @@ def gen_sqlgen_conversation(dbname):
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
def get_database_list():
mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list()
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
@ -96,6 +99,8 @@ function() {
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
@ -113,6 +118,7 @@ def load_demo(url_params, request: gr.Request):
gr.Row.update(visible=True),
gr.Accordion.update(visible=True))
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
@ -125,9 +131,8 @@ def regenerate(state, request: gr.Request):
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def clear_history(request: gr.Request):
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = None
return (state, [], "") + (disable_btn,) * 5
@ -140,12 +145,13 @@ def add_text(state, text, request: gr.Request):
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
""" Default support 4000 tokens, if tokens too lang, we will cut off """
text = text[:4000]
text = text[:4000]
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
def post_process_code(code):
sep = "\n```"
if sep in code:
@ -156,6 +162,7 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
print("AUTO DB-GPT模式.")
@ -183,7 +190,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# 第一轮对话需要加入提示Prompt
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
# autogpt模式的第一轮对话需要 构建专属prompt
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
db_schemes=gen_sqlgen_conversation(dbname))
logger.info("[TEST]:" + system_prompt)
template_name = "auto_dbgpt_one_shot"
new_state = conv_templates[template_name].copy()
@ -210,20 +218,18 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# 第一轮对话需要加入提示Prompt
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
## 获取最后一次插件的返回
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
state.messages[0][0] = ""
state.messages[0][1] = ""
state.messages[-2][1] = follow_up_prompt
if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
@ -243,11 +249,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
print(response.json())
print(str(response))
try:
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
# response = response.replace("\n", "\\n")
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
# response = response.replace("\n", "\\)
text = response.text.strip()
text = text.rstrip()
respObj = json.loads(text)
@ -257,7 +258,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
respObj_ex = json.loads(xx)
if respObj_ex['error_code'] == 0:
ai_response = None
all_text = respObj_ex['text']
all_text = respObj_ex['text']
### 解析返回文本获取AI回复部分
tmpResp = all_text.split(state.sep)
last_index = -1
@ -277,11 +278,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
cfg.set_last_plugin_return(plugin_resp)
print(plugin_resp)
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
except NotCommands as e:
print("命令执行:" + e.message)
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
else:
# 流式输出
@ -304,7 +305,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
except requests.exceptions.RequestException as e:
@ -333,8 +334,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
block_css = (
code_highlight_css
+ """
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@ -348,23 +349,26 @@ block_css = (
"""
)
def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
def change_tab():
autogpt = True
autogpt = True
def build_single_model_ui():
notice_markdown = """
# DB-GPT
@ -396,7 +400,7 @@ def build_single_model_ui():
interactive=True,
label="最大输出Token数",
)
tabs= gr.Tabs()
tabs = gr.Tabs()
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
with tab_sql:
@ -427,7 +431,7 @@ def build_single_model_ui():
with gr.Column() as doc2vec:
gr.Markdown("向知识库中添加文件")
with gr.Tab("上传文件"):
files = gr.File(label="添加文件",
files = gr.File(label="添加文件",
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
show_label=False
@ -436,11 +440,10 @@ def build_single_model_ui():
load_file_button = gr.Button("上传并加载到知识库")
with gr.Tab("上传文件夹"):
folder_files = gr.File(label="添加文件",
file_count="directory",
show_label=False)
file_count="directory",
show_label=False)
load_folder_button = gr.Button("上传并加载到知识库")
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
with gr.Row():
@ -449,9 +452,9 @@ def build_single_model_ui():
show_label=False,
placeholder="Enter text and press ENTER",
visible=False,
).style(container=False)
).style(container=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value="发送", visible=False)
send_btn = gr.Button(value="发送", visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value="重新生成", interactive=False)
@ -465,7 +468,7 @@ def build_single_model_ui():
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
textbox.submit(
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
@ -487,10 +490,10 @@ def build_single_model_ui():
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(
@ -520,6 +523,7 @@ def build_webdemo():
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
@ -553,7 +557,7 @@ if __name__ == "__main__":
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry =command_registry
cfg.command_registry = command_registry
logger.info(args)
demo = build_webdemo()
@ -561,4 +565,4 @@ if __name__ == "__main__":
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
).launch(
server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
)
)