mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 06:30:02 +00:00
add plugin mode
This commit is contained in:
@@ -103,6 +103,11 @@ def gen_sqlgen_conversation(dbname):
|
||||
return f"数据库{dbname}的Schema信息如下: {message}\n"
|
||||
|
||||
|
||||
def plugins_select_info():
|
||||
plugins_infos: dict = {}
|
||||
for plugin in CFG.plugins:
|
||||
plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
|
||||
return plugins_infos
|
||||
|
||||
|
||||
get_window_url_params = """
|
||||
@@ -188,26 +193,27 @@ def post_process_code(code):
|
||||
return code
|
||||
|
||||
|
||||
def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
|
||||
if mode == conversation_types["default_knownledge"] and not db_selector:
|
||||
return ChatScene.ChatKnowledge
|
||||
elif mode == conversation_types["custome"] and not db_selector:
|
||||
return ChatScene.ChatNewKnowledge
|
||||
elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
|
||||
return ChatScene.ChatWithDb
|
||||
|
||||
elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
|
||||
def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
|
||||
if "插件模式" == selected:
|
||||
return ChatScene.ChatExecution
|
||||
elif "知识问答" == selected:
|
||||
if mode == conversation_types["default_knownledge"]:
|
||||
return ChatScene.ChatKnowledge
|
||||
elif mode == conversation_types["custome"]:
|
||||
return ChatScene.ChatNewKnowledge
|
||||
else:
|
||||
return ChatScene.ChatNormal
|
||||
if sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
|
||||
return ChatScene.ChatWithDb
|
||||
|
||||
return ChatScene.ChatNormal
|
||||
|
||||
|
||||
def http_bot(
|
||||
state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
|
||||
state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
|
||||
):
|
||||
logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
|
||||
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
|
||||
start_tstamp = time.time()
|
||||
scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
|
||||
scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
|
||||
print(f"当前对话模式:{scene.value}")
|
||||
model_name = CFG.LLM_MODEL
|
||||
|
||||
@@ -216,6 +222,17 @@ def http_bot(
|
||||
chat_param = {
|
||||
"chat_session_id": state.conv_id,
|
||||
"db_name": db_selector,
|
||||
"current_user_input": state.last_user_input,
|
||||
}
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
|
||||
chat.call()
|
||||
state.messages[-1][-1] = f"{chat.current_ai_response()}"
|
||||
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
||||
elif ChatScene.ChatExecution == scene:
|
||||
logger.info("插件模式对话走新的模式!")
|
||||
chat_param = {
|
||||
"chat_session_id": state.conv_id,
|
||||
"plugin_selector": plugin_selector,
|
||||
"user_input": state.last_user_input,
|
||||
}
|
||||
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
|
||||
@@ -362,8 +379,8 @@ def http_bot(
|
||||
|
||||
|
||||
block_css = (
|
||||
code_highlight_css
|
||||
+ """
|
||||
code_highlight_css
|
||||
+ """
|
||||
pre {
|
||||
white-space: pre-wrap; /* Since CSS 2.1 */
|
||||
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
|
||||
@@ -396,6 +413,11 @@ def change_tab():
|
||||
autogpt = True
|
||||
|
||||
|
||||
def change_func(xx):
|
||||
print("123")
|
||||
print(str(xx))
|
||||
|
||||
|
||||
def build_single_model_ui():
|
||||
notice_markdown = """
|
||||
# DB-GPT
|
||||
@@ -430,11 +452,18 @@ def build_single_model_ui():
|
||||
label="最大输出Token数",
|
||||
)
|
||||
|
||||
|
||||
tabs = gr.Tabs()
|
||||
|
||||
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
|
||||
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
||||
return evt.value
|
||||
|
||||
selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
|
||||
tabs.select(on_select, None, selected)
|
||||
|
||||
with tabs:
|
||||
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
|
||||
tab_sql.select(on_select, None, None)
|
||||
with tab_sql:
|
||||
print("tab_sql in...")
|
||||
# TODO A selector to choose database
|
||||
@@ -452,18 +481,26 @@ def build_single_model_ui():
|
||||
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
|
||||
|
||||
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
|
||||
# tab_plugin.select(change_func)
|
||||
with tab_plugin:
|
||||
print("tab_plugin in...")
|
||||
with gr.Row(elem_id="plugin_selector"):
|
||||
# TODO
|
||||
plugin_selector = gr.Dropdown(
|
||||
label="请选择插件",
|
||||
choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""],
|
||||
value="datadance-ddl-excutor",
|
||||
choices=list(plugins_select_info().keys()),
|
||||
value="",
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
type="value"
|
||||
).style(container=False)
|
||||
|
||||
def plugin_change(evt: gr.SelectData): # SelectData is a subclass of EventData
|
||||
print(f"You selected {evt.value} at {evt.index} from {evt.target}")
|
||||
return plugins_select_info().get(evt.value)
|
||||
|
||||
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
|
||||
plugin_selector.select(plugin_change, None, plugin_selected)
|
||||
|
||||
tab_qa = gr.TabItem("知识问答", elem_id="QA")
|
||||
with tab_qa:
|
||||
@@ -517,7 +554,7 @@ def build_single_model_ui():
|
||||
btn_list = [regenerate_btn, clear_btn]
|
||||
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
||||
http_bot,
|
||||
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
||||
@@ -526,7 +563,7 @@ def build_single_model_ui():
|
||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||
).then(
|
||||
http_bot,
|
||||
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
|
||||
@@ -534,7 +571,7 @@ def build_single_model_ui():
|
||||
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
|
||||
).then(
|
||||
http_bot,
|
||||
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
|
||||
[state, chatbot] + btn_list,
|
||||
)
|
||||
vs_add.click(
|
||||
@@ -557,10 +594,10 @@ def build_single_model_ui():
|
||||
|
||||
def build_webdemo():
|
||||
with gr.Blocks(
|
||||
title="数据库智能助手",
|
||||
# theme=gr.themes.Base(),
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
title="数据库智能助手",
|
||||
# theme=gr.themes.Base(),
|
||||
theme=gr.themes.Default(),
|
||||
css=block_css,
|
||||
) as demo:
|
||||
url_params = gr.JSON(visible=False)
|
||||
(
|
||||
|
Reference in New Issue
Block a user