mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
插件输出优化
This commit is contained in:
parent
d5b5fc4f9a
commit
e2750fcea0
@ -69,7 +69,7 @@ def execute_ai_response_json(
|
||||
arguments,
|
||||
prompt,
|
||||
)
|
||||
result = f"Command {command_name} returned: " f"{command_result}"
|
||||
result = f"{command_result}"
|
||||
return result
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.auto_mode_prompt import AutoModePrompt
|
||||
from pilot.prompts.auto_mode_prompt import AutoModePrompt
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
@ -60,14 +60,15 @@ priority = {
|
||||
"vicuna-13b": "aaa"
|
||||
}
|
||||
|
||||
|
||||
def get_simlar(q):
|
||||
|
||||
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
|
||||
docs = docsearch.similarity_search_with_score(q, k=1)
|
||||
|
||||
contents = [dc.page_content for dc, _ in docs]
|
||||
return "\n".join(contents)
|
||||
|
||||
|
||||
|
||||
def gen_sqlgen_conversation(dbname):
|
||||
mo = MySQLOperator(
|
||||
**DB_SETTINGS
|
||||
@ -80,10 +81,12 @@ def gen_sqlgen_conversation(dbname):
|
||||
message += s["schema_info"] + ";"
|
||||
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||
|
||||
|
||||
def get_database_list():
|
||||
mo = MySQLOperator(**DB_SETTINGS)
|
||||
return mo.get_db_list()
|
||||
|
||||
|
||||
get_window_url_params = """
|
||||
function() {
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
@ -96,6 +99,8 @@ function() {
|
||||
return url_params;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def load_demo(url_params, request: gr.Request):
|
||||
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
||||
|
||||
@ -113,6 +118,7 @@ def load_demo(url_params, request: gr.Request):
|
||||
gr.Row.update(visible=True),
|
||||
gr.Accordion.update(visible=True))
|
||||
|
||||
|
||||
def get_conv_log_filename():
|
||||
t = datetime.datetime.now()
|
||||
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
||||
@ -125,9 +131,8 @@ def regenerate(state, request: gr.Request):
|
||||
state.skip_next = False
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
def clear_history(request: gr.Request):
|
||||
|
||||
|
||||
def clear_history(request: gr.Request):
|
||||
logger.info(f"clear_history. ip: {request.client.host}")
|
||||
state = None
|
||||
return (state, [], "") + (disable_btn,) * 5
|
||||
@ -140,12 +145,13 @@ def add_text(state, text, request: gr.Request):
|
||||
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
||||
|
||||
""" Default support 4000 tokens, if tokens too lang, we will cut off """
|
||||
text = text[:4000]
|
||||
text = text[:4000]
|
||||
state.append_message(state.roles[0], text)
|
||||
state.append_message(state.roles[1], None)
|
||||
state.skip_next = False
|
||||
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
||||
|
||||
|
||||
def post_process_code(code):
|
||||
sep = "\n```"
|
||||
if sep in code:
|
||||
@ -156,6 +162,7 @@ def post_process_code(code):
|
||||
code = sep.join(blocks)
|
||||
return code
|
||||
|
||||
|
||||
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
print("AUTO DB-GPT模式.")
|
||||
@ -183,7 +190,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
# autogpt模式的第一轮对话需要 构建专属prompt
|
||||
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
|
||||
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query],
|
||||
db_schemes=gen_sqlgen_conversation(dbname))
|
||||
logger.info("[TEST]:" + system_prompt)
|
||||
template_name = "auto_dbgpt_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
@ -210,20 +218,18 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
## 获取最后一次插件的返回
|
||||
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
|
||||
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
|
||||
state.messages[0][0] = ""
|
||||
state.messages[0][1] = ""
|
||||
state.messages[-2][1] = follow_up_prompt
|
||||
|
||||
|
||||
if mode == conversation_types["default_knownledge"] and not db_selector:
|
||||
query = state.messages[-2][1]
|
||||
knqa = KnownLedgeBaseQA()
|
||||
state.messages[-2][1] = knqa.get_similar_answer(query)
|
||||
|
||||
|
||||
prompt = state.get_prompt()
|
||||
|
||||
|
||||
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
|
||||
|
||||
# Make requests
|
||||
@ -243,11 +249,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
print(response.json())
|
||||
print(str(response))
|
||||
try:
|
||||
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
|
||||
# response = response.replace("\n", "\\n")
|
||||
|
||||
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
|
||||
# response = response.replace("\n", "\\)
|
||||
text = response.text.strip()
|
||||
text = text.rstrip()
|
||||
respObj = json.loads(text)
|
||||
@ -257,7 +258,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
respObj_ex = json.loads(xx)
|
||||
if respObj_ex['error_code'] == 0:
|
||||
ai_response = None
|
||||
all_text = respObj_ex['text']
|
||||
all_text = respObj_ex['text']
|
||||
### 解析返回文本,获取AI回复部分
|
||||
tmpResp = all_text.split(state.sep)
|
||||
last_index = -1
|
||||
@ -277,11 +278,11 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
|
||||
cfg.set_last_plugin_return(plugin_resp)
|
||||
print(plugin_resp)
|
||||
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
|
||||
state.messages[-1][-1] = "Model推理信息:\n" + ai_response + "\n\nDB-GPT执行结果:\n" + plugin_resp
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
except NotCommands as e:
|
||||
print("命令执行:" + e.message)
|
||||
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
|
||||
state.messages[-1][-1] = "命令执行:" + e.message + "\n模型输出:\n" + str(ai_response)
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
else:
|
||||
# 流式输出
|
||||
@ -304,7 +305,7 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
output = data["text"] + f" (error_code: {data['error_code']})"
|
||||
state.messages[-1][-1] = output
|
||||
yield (state, state.to_gradio_chatbot()) + (
|
||||
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
||||
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
||||
return
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
@ -333,8 +334,8 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
|
||||
|
||||
block_css = (
|
||||
code_highlight_css
|
||||
+ """
|
||||
code_highlight_css
|
||||
+ """
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
@ -348,23 +349,26 @@ block_css = (
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def change_sql_mode(sql_mode):
|
||||
if sql_mode in ["直接执行结果"]:
|
||||
return gr.update(visible=True)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
|
||||
|
||||
def change_mode(mode):
|
||||
if mode in ["默认知识库对话", "LLM原生对话"]:
|
||||
return gr.update(visible=False)
|
||||
else:
|
||||
return gr.update(visible=True)
|
||||
|
||||
|
||||
def change_tab():
|
||||
autogpt = True
|
||||
|
||||
autogpt = True
|
||||
|
||||
|
||||
def build_single_model_ui():
|
||||
|
||||
notice_markdown = """
|
||||
# DB-GPT
|
||||
|
||||
@ -396,7 +400,7 @@ def build_single_model_ui():
|
||||
interactive=True,
|
||||
label="最大输出Token数",
|
||||
)
|
||||
tabs= gr.Tabs()
|
||||
tabs = gr.Tabs()
|
||||
with tabs:
|
||||
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||
with tab_sql:
|
||||
@ -427,7 +431,7 @@ def build_single_model_ui():
|
||||
with gr.Column() as doc2vec:
|
||||
gr.Markdown("向知识库中添加文件")
|
||||
with gr.Tab("上传文件"):
|
||||
files = gr.File(label="添加文件",
|
||||
files = gr.File(label="添加文件",
|
||||
file_types=[".txt", ".md", ".docx", ".pdf"],
|
||||
file_count="multiple",
|
||||
show_label=False
|
||||
@ -436,11 +440,10 @@ def build_single_model_ui():
|
||||
load_file_button = gr.Button("上传并加载到知识库")
|
||||
with gr.Tab("上传文件夹"):
|
||||
folder_files = gr.File(label="添加文件",
|
||||
file_count="directory",
|
||||
show_label=False)
|
||||
file_count="directory",
|
||||
show_label=False)
|
||||
load_folder_button = gr.Button("上传并加载到知识库")
|
||||
|
||||
|
||||
|
||||
with gr.Blocks():
|
||||
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
|
||||
with gr.Row():
|
||||
@ -449,9 +452,9 @@ def build_single_model_ui():
|
||||
show_label=False,
|
||||
placeholder="Enter text and press ENTER",
|
||||
visible=False,
|
||||
).style(container=False)
|
||||
).style(container=False)
|
||||
with gr.Column(scale=2, min_width=50):
|
||||
send_btn = gr.Button(value="发送", visible=False)
|
||||
send_btn = gr.Button(value="发送", visible=False)
|
||||
|
||||
with gr.Row(visible=False) as button_row:
|
||||
regenerate_btn = gr.Button(value="重新生成", interactive=False)
|
||||
@ -465,7 +468,7 @@ def build_single_model_ui():
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
||||
|
||||
|
||||
textbox.submit(
|
||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||
).then(
|
||||
@ -487,10 +490,10 @@ def build_single_model_ui():
|
||||
|
||||
def build_webdemo():
|
||||
with gr.Blocks(
|
||||
title="数据库智能助手",
|
||||
# theme=gr.themes.Base(),
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
title="数据库智能助手",
|
||||
# theme=gr.themes.Base(),
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
) as demo:
|
||||
url_params = gr.JSON(visible=False)
|
||||
(
|
||||
@ -520,6 +523,7 @@ def build_webdemo():
|
||||
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
@ -553,7 +557,7 @@ if __name__ == "__main__":
|
||||
for command_category in command_categories:
|
||||
command_registry.import_commands(command_category)
|
||||
|
||||
cfg.command_registry =command_registry
|
||||
cfg.command_registry = command_registry
|
||||
|
||||
logger.info(args)
|
||||
demo = build_webdemo()
|
||||
@ -561,4 +565,4 @@ if __name__ == "__main__":
|
||||
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
||||
).launch(
|
||||
server_name=args.host, server_port=args.port, share=args.share, max_threads=200,
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user