Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-11 05:49:22 +00:00
Commit: Sync service interface (同步服务接口)
@@ -21,11 +21,15 @@ from pilot.plugins import scan_plugins
from pilot.configs.config import Config
from pilot.commands.command_mange import CommandRegistry
from pilot.prompts.first_conversation_prompt import FirstPrompt
from pilot.prompts.generator import PromptGenerator

from pilot.commands.exception_not_commands import NotCommands

from pilot.conversation import (
    default_conversation,
    conv_templates,
    conversation_types,
    conversation_sql_mode,
    SeparatorStyle
)
@@ -152,11 +156,13 @@ def post_process_code(code):
    code = sep.join(blocks)
    return code

-def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):

# MOCk
autogpt = True
+def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
        print("AUTO DB-GPT模式.")
    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
        print("标准DB-GPT模式.")
    print("是否是AUTO-GPT模式.", autogpt)

    start_tstamp = time.time()
    model_name = LLM_MODEL
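Aside: the new sql_mode parameter above is compared against entries of conversation_sql_mode, which this diff only imports. A minimal sketch of how that switch behaves, assuming for illustration that conversation_sql_mode is a plain dict whose values match the radio labels wired up further down (the real mapping lives in pilot/conversation.py and is not shown in this diff):

# Hypothetical stand-in for pilot.conversation.conversation_sql_mode;
# the real values are defined in the repository, not in this diff.
conversation_sql_mode = {
    "auto_execute_ai_response": "直接执行结果",
    "dont_execute_ai_response": "不执行结果",
}


def pick_branch(sql_mode):
    # Mirrors the branching that the new http_bot signature enables.
    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
        return "auto"      # let DB-GPT execute the generated SQL / plugin command
    if sql_mode == conversation_sql_mode["dont_execute_ai_response"]:
        return "standard"  # only stream the model's answer back to the UI
    return "unknown"


assert pick_branch("不执行结果") == "standard"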
@@ -167,17 +173,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    cfg = Config()
    first_prompt = FirstPrompt()


    # TODO: when the tab mode is AUTO_GPT, the prompt needs to be rebuilt.
    if len(state.messages) == state.offset + 2:
        query = state.messages[-2][1]
        # The first turn of the conversation needs the prompt prepended
        cfg = Config()
        first_prompt = FirstPrompt()
        first_prompt.command_registry = cfg.command_registry
-        if(autogpt):
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
            # The first turn in auto-gpt mode needs its own dedicated prompt
-            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query])
+            system_prompt = first_prompt.construct_first_prompt(fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname))
            logger.info("[TEST]:" + system_prompt)
            template_name = "auto_dbgpt_one_shot"
            new_state = conv_templates[template_name].copy()
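The construct_first_prompt change above folds the database schema into the first-turn system prompt. A rough, hypothetical illustration of that idea; build_first_prompt, describe_schema, and the template text are inventions for this sketch, not the FirstPrompt or gen_sqlgen_conversation implementations in the repository:

def describe_schema(tables):
    # Hypothetical helper: flatten {table: [columns]} into prompt text.
    return "\n".join(f"{name}({', '.join(cols)})" for name, cols in tables.items())


def build_first_prompt(query, schema_text):
    # Hypothetical sketch of a first-turn system prompt that embeds the schema,
    # in the spirit of construct_first_prompt(fisrt_message=[query], db_schemes=...).
    return (
        "You are DB-GPT. Answer with SQL when appropriate.\n"
        f"Database schema:\n{schema_text}\n"
        f"User question: {query}"
    )


schema = describe_schema({"users": ["id", "name"], "orders": ["id", "user_id", "total"]})
print(build_first_prompt("Total spend per user?", schema))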
@@ -218,7 +225,13 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
    }
    logger.info(f"Requert: \n{payload}")
    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
        auto_db_gpt_response(first_prompt.prompt_generator, payload)
    else:
        stream_ai_response(payload)

    def stream_ai_response(payload):
        # Streaming output
        state.messages[-1][-1] = "▌"
        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
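One thing to note about the dispatch above: stream_ai_response contains yield, so calling it as a plain statement only creates a generator object and never runs its body, and its partial states never reach the caller. A small sketch of the delegation pattern that streaming handlers typically need, under the assumption that http_bot itself is consumed as a generator by Gradio (the names here are illustrative stand-ins):

def stream_tokens(payload):
    # Generator: yields partial output as it arrives (stand-in for stream_ai_response).
    for chunk in ["SEL", "SELECT 1", "SELECT 1;"]:
        yield chunk


def handler(payload):
    # Plain call: creates a generator object and discards it; nothing is streamed.
    stream_tokens(payload)
    # Delegation: re-yields every partial result to the caller (here, Gradio).
    yield from stream_tokens(payload)


print(list(handler({})))  # ['SEL', 'SELECT 1', 'SELECT 1;']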
@@ -264,6 +277,18 @@ def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
        }
        fout.write(json.dumps(data) + "\n")


def auto_db_gpt_response(prompt: PromptGenerator, payload) -> str:
    response = requests.post(urljoin(VICUNA_MODEL_SERVER, "generate"),
                             headers=headers, json=payload, timeout=30)
    print(response)
    try:
        plugin_resp = execute_ai_response_json(prompt, response)
        print(plugin_resp)
    except NotCommands as e:
        print(str(e))
    return "auto_db_gpt_response!"

block_css = (
    code_highlight_css
    + """
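For context, the shape of auto_db_gpt_response above is: POST the payload to the model server, hand the result to the plugin executor, and treat NotCommands as a non-fatal case. A self-contained sketch of that pattern with the executor stubbed out; MODEL_SERVER_URL, execute_stub, and the local NotCommands class are placeholders, while the real execute_ai_response_json and NotCommands come from pilot.commands:

import requests
from urllib.parse import urljoin

MODEL_SERVER_URL = "http://127.0.0.1:8000"  # placeholder for VICUNA_MODEL_SERVER


class NotCommands(Exception):
    # Stand-in for pilot.commands.exception_not_commands.NotCommands
    pass


def execute_stub(response_text):
    # Placeholder for execute_ai_response_json: pretend non-JSON output is not a command.
    if not response_text.strip().startswith("{"):
        raise NotCommands("model output is not a command payload")
    return "executed"


def auto_response(payload):
    resp = requests.post(urljoin(MODEL_SERVER_URL, "generate"), json=payload, timeout=30)
    try:
        return execute_stub(resp.text)
    except NotCommands as e:
        # Swallow the error and fall back, mirroring the diff's behavior.
        print(str(e))
        return "no command executed"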
@@ -280,6 +305,12 @@ block_css = (
    """
)

def change_sql_mode(sql_mode):
    if sql_mode in ["直接执行结果"]:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

def change_mode(mode):
    if mode in ["默认知识库对话", "LLM原生对话"]:
        return gr.update(visible=False)
@@ -325,6 +356,9 @@ def build_single_model_ui():
        tabs = gr.Tabs()
        with tabs:
            tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
            sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
            sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
            sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
            with tab_sql:
                # TODO A selector to choose database
                with gr.Row(elem_id="db_selector"):
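The two hunks above work as a pair: change_sql_mode returns a gr.update(visible=...) patch, and the Radio's change event routes it to the Markdown hint. A minimal standalone sketch of that wiring, assuming a Gradio 3.x API; the component names and the English hint text here are illustrative, not the ones from build_single_model_ui:

import gradio as gr


def change_sql_mode(sql_mode):
    # Show the hint only when auto-execution is selected.
    return gr.update(visible=(sql_mode == "直接执行结果"))


with gr.Blocks() as demo:
    sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
    hint = gr.Markdown("In auto-execute mode, DB-GPT can run the generated SQL.", visible=False)
    sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=hint)

if __name__ == "__main__":
    demo.launch()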
@@ -383,7 +417,7 @@ def build_single_model_ui():
    btn_list = [regenerate_btn, clear_btn]
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -392,7 +426,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
    )
@@ -400,7 +434,7 @@ def build_single_model_ui():
        add_text, [state, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        http_bot,
-        [state, mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list
    )
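All three hunks above make the same one-line change: sql_mode joins the inputs list that Gradio passes positionally to http_bot, matching its new signature. A compact sketch of the .click(...).then(...) chaining pattern they rely on; the component names and the trivial handlers are illustrative only:

import gradio as gr


def add_text(history, text):
    # First step of the chain: append the user's message and clear the textbox.
    return history + [[text, None]], ""


def http_bot(history, sql_mode):
    # Second step: positional inputs must line up with the inputs list below.
    history[-1][1] = f"(mode={sql_mode}) ..."
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    sql_mode = gr.Radio(["直接执行结果", "不执行结果"], value="不执行结果")
    send = gr.Button("Send")
    send.click(add_text, [chatbot, textbox], [chatbot, textbox]).then(
        http_bot,
        [chatbot, sql_mode],  # adding sql_mode here is the whole point of the hunks above
        chatbot,
    )

if __name__ == "__main__":
    demo.launch()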