mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 13:10:29 +00:00
多轮对话使用插件
This commit is contained in:
@@ -20,7 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
|
||||
from pilot.plugins import scan_plugins
|
||||
from pilot.configs.config import Config
|
||||
from pilot.commands.command_mange import CommandRegistry
|
||||
from pilot.prompts.first_conversation_prompt import FirstPrompt
|
||||
from pilot.prompts.auto_mode_prompt import AutoModePrompt
|
||||
from pilot.prompts.generator import PromptGenerator
|
||||
|
||||
from pilot.commands.exception_not_commands import NotCommands
|
||||
@@ -174,28 +174,25 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
return
|
||||
|
||||
cfg = Config()
|
||||
first_prompt = FirstPrompt()
|
||||
|
||||
auto_prompt = AutoModePrompt()
|
||||
auto_prompt.command_registry = cfg.command_registry
|
||||
|
||||
# TODO when tab mode is AUTO_GPT, Prompt need to rebuild.
|
||||
if len(state.messages) == state.offset + 2:
|
||||
query = state.messages[-2][1]
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
first_prompt.command_registry = cfg.command_registry
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
# autogpt模式的第一轮对话需要 构建专属prompt
|
||||
system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
|
||||
system_prompt = auto_prompt.construct_first_prompt(fisrt_message=[query], db_schemes= gen_sqlgen_conversation(dbname))
|
||||
logger.info("[TEST]:" + system_prompt)
|
||||
template_name = "auto_dbgpt_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
new_state.append_message(role='USER', message=system_prompt)
|
||||
# new_state.append_message(new_state.roles[0], query)
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
else:
|
||||
template_name = "conv_one_shot"
|
||||
new_state = conv_templates[template_name].copy()
|
||||
|
||||
new_state.conv_id = uuid.uuid4().hex
|
||||
|
||||
if not autogpt:
|
||||
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
|
||||
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
|
||||
if db_selector:
|
||||
@@ -205,7 +202,20 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
new_state.append_message(new_state.roles[0], query)
|
||||
new_state.append_message(new_state.roles[1], None)
|
||||
|
||||
new_state.conv_id = uuid.uuid4().hex
|
||||
state = new_state
|
||||
else:
|
||||
### 后续对话
|
||||
query = state.messages[-2][1]
|
||||
# 第一轮对话需要加入提示Prompt
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
## 获取最后一次插件的返回
|
||||
follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
|
||||
state.messages[0][0] = ""
|
||||
state.messages[0][1] = ""
|
||||
state.messages[-2][1] = follow_up_prompt
|
||||
|
||||
|
||||
if mode == conversation_types["default_knownledge"] and not db_selector:
|
||||
query = state.messages[-2][1]
|
||||
knqa = KnownLedgeBaseQA()
|
||||
@@ -228,20 +238,50 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
|
||||
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
|
||||
response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
|
||||
headers=headers, json=payload, timeout=30)
|
||||
headers=headers, json=payload, timeout=120)
|
||||
|
||||
print(response.json())
|
||||
print(str(response))
|
||||
try:
|
||||
# response = """{"thoughts":{"text":"thought","reasoning":"reasoning","plan":"- short bulleted\n- list that conveys\n- long-term plan","criticism":"constructive self-criticism","speak":"thoughts summary to say to user"},"command":{"name":"db_sql_executor","args":{"sql":"select count(*) as user_count from users u where create_time >= DATE_SUB(NOW(), INTERVAL 1 MONTH);"}}}"""
|
||||
# response = response.replace("\n", "\\n")
|
||||
plugin_resp = execute_ai_response_json(first_prompt.prompt_generator, response)
|
||||
print(plugin_resp)
|
||||
state.messages[-1][-1] = "DB-GPT执行结果:\n" + plugin_resp
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
|
||||
# response = """{"thoughts":{"text":"In order to get the number of users who have grown in the last three days, I need to analyze the create\_time of each user and see if it is within the last three days. I will use the SQL query to filter the users who have created their account in the last three days.","reasoning":"I can use the SQL query to filter the users who have created their account in the last three days. I will get the current date and then subtract three days from it, and then use this as the filter for the query. This will give me the number of users who have created their account in the last three days.","plan":"- Get the current date and subtract three days from it\n- Use the SQL query to filter the users who have created their account in the last three days\n- Count the number of users who match the filter to get the number of users who have grown in the last three days","criticism":"None"},"command":{"name":"db_sql_executor","args":{"sql":"SELECT COUNT(DISTINCT(ID)) FROM users WHERE create_time >= DATE_SUB(NOW(), INTERVAL 3 DAY);"}}}"""
|
||||
# response = response.replace("\n", "\\)
|
||||
text = response.text.strip()
|
||||
text = text.rstrip()
|
||||
respObj = json.loads(text)
|
||||
|
||||
xx = respObj['response']
|
||||
xx = xx.strip(b'\x00'.decode())
|
||||
respObj_ex = json.loads(xx)
|
||||
if respObj_ex['error_code'] == 0:
|
||||
ai_response = None
|
||||
all_text = respObj_ex['text']
|
||||
### 解析返回文本,获取AI回复部分
|
||||
tmpResp = all_text.split(state.sep)
|
||||
last_index = -1
|
||||
for i in range(len(tmpResp)):
|
||||
if tmpResp[i].find('ASSISTANT:') != -1:
|
||||
last_index = i
|
||||
ai_response = tmpResp[last_index]
|
||||
ai_response = ai_response.replace("ASSISTANT:", "")
|
||||
ai_response = ai_response.replace("\n", "")
|
||||
ai_response = ai_response.replace("\_", "_")
|
||||
|
||||
print(ai_response)
|
||||
if ai_response == None:
|
||||
state.messages[-1][-1] = "ASSISTANT未能正确回复,回复结果为:\n" + all_text
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
else:
|
||||
plugin_resp = execute_ai_response_json(auto_prompt.prompt_generator, ai_response)
|
||||
cfg.set_last_plugin_return(plugin_resp)
|
||||
print(plugin_resp)
|
||||
state.messages[-1][-1] = "Model推理信息:\n"+ ai_response +"\n\nDB-GPT执行结果:\n" + plugin_resp
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
except NotCommands as e:
|
||||
print("命令执行:" + str(e))
|
||||
state.messages[-1][-1] = "命令执行:" + str(e) +"\n模型输出:\n" + str(response)
|
||||
print("命令执行:" + e.message)
|
||||
state.messages[-1][-1] = "命令执行:" + e.message +"\n模型输出:\n" + str(ai_response)
|
||||
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
||||
else:
|
||||
# 流式输出
|
||||
@@ -350,8 +390,8 @@ def build_single_model_ui():
|
||||
|
||||
max_output_tokens = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=4096,
|
||||
value=2048,
|
||||
maximum=1024,
|
||||
value=1024,
|
||||
step=64,
|
||||
interactive=True,
|
||||
label="最大输出Token数",
|
||||
@@ -359,9 +399,6 @@ def build_single_model_ui():
|
||||
tabs= gr.Tabs()
|
||||
with tabs:
|
||||
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
|
||||
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
||||
with tab_sql:
|
||||
# TODO A selector to choose database
|
||||
with gr.Row(elem_id="db_selector"):
|
||||
@@ -370,7 +407,11 @@ def build_single_model_ui():
|
||||
choices=dbs,
|
||||
value=dbs[0] if len(models) > 0 else "",
|
||||
interactive=True,
|
||||
show_label=True).style(container=False)
|
||||
show_label=True).style(container=False)
|
||||
|
||||
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
|
||||
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
||||
tab_auto = gr.TabItem("AUTO-GPT", elem_id="auto")
|
||||
with tab_auto:
|
||||
gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
|
||||
|
Reference in New Issue
Block a user