Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-16 14:40:56 +00:00
Merge branch 'dev' into llm_fxp
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

+import traceback
 import argparse
 import datetime
 import json
@@ -9,7 +9,6 @@ import shutil
 import sys
 import time
 import uuid
-from urllib.parse import urljoin

 import gradio as gr
 import requests
@@ -30,18 +29,19 @@ from pilot.configs.model_config import (
     LOGDIR,
     VECTOR_SEARCH_TOP_K,
 )
-from pilot.connections.mysql import MySQLOperator
-
 from pilot.conversation import (
     SeparatorStyle,
     conv_qa_prompt_template,
     conv_templates,
     conversation_sql_mode,
     conversation_types,
+    chat_mode_title,
     default_conversation,
 )
-from pilot.plugins import scan_plugins
-from pilot.prompts.auto_mode_prompt import AutoModePrompt
-from pilot.prompts.generator import PromptGenerator
+from pilot.common.plugins import scan_plugins
+
+from pilot.prompts.generator import PluginPromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
@@ -95,6 +95,11 @@ default_knowledge_base_dialogue = get_lang_text(
 add_knowledge_base_dialogue = get_lang_text(
     "knowledge_qa_type_add_knowledge_base_dialogue"
 )

+url_knowledge_dialogue = get_lang_text(
+    "knowledge_qa_type_url_knowledge_dialogue"
+)
+
 knowledge_qa_type_list = [
     llm_native_dialogue,
     default_knowledge_base_dialogue,
@@ -111,19 +116,19 @@ def get_simlar(q):
 

 def gen_sqlgen_conversation(dbname):
-    mo = MySQLOperator(**DB_SETTINGS)
-
     message = ""

-    schemas = mo.get_schema(dbname)
+    db_connect = CFG.local_db.get_session(dbname)
+    schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
-        message += s["schema_info"] + ";"
+        message += s + ";"
     return get_lang_text("sql_schema_info").format(dbname, message)


-def get_database_list():
-    mo = MySQLOperator(**DB_SETTINGS)
-    return mo.get_db_list()
+def plugins_select_info():
+    plugins_infos: dict = {}
+    for plugin in CFG.plugins:
+        plugins_infos.update({f"【{plugin._name}】=>{plugin._description}": plugin._name})
+    return plugins_infos
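For reference, the label-to-name mapping that plugins_select_info() builds can be exercised on its own: the UI shows the decorated label, and the internal plugin name is what the execution path receives. A minimal self-contained sketch, with a hypothetical Plugin stub that is not part of this commit:

from dataclasses import dataclass

@dataclass
class Plugin:  # hypothetical stand-in for a loaded DB-GPT plugin
    _name: str
    _description: str

plugins = [Plugin("search", "web search"), Plugin("sql", "run SQL")]

# Same shape as plugins_select_info(): display label -> internal name
infos = {f"【{p._name}】=>{p._description}": p._name for p in plugins}

label = list(infos.keys())[0]  # what the dropdown shows
print(infos[label])            # -> "search", what the backend receives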
 

 get_window_url_params = """
@@ -210,285 +215,127 @@ def post_process_code(code):
     return code


-def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
-    if mode == conversation_types["default_knownledge"] and not db_selector:
-        return ChatScene.ChatKnowledge
-    elif mode == conversation_types["custome"] and not db_selector:
-        return ChatScene.ChatNewKnowledge
-    elif sql_mode == conversation_sql_mode["auto_execute_ai_response"] and db_selector:
-        return ChatScene.ChatWithDb
-    elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
-        return ChatScene.ChatExecution
-    else:
-        return ChatScene.ChatNormal
+def get_chat_mode(selected, param=None) -> ChatScene:
+    if chat_mode_title['chat_use_plugin'] == selected:
+        return ChatScene.ChatExecution
+    elif chat_mode_title['knowledge_qa'] == selected:
+        mode = param
+        if mode == conversation_types["default_knownledge"]:
+            return ChatScene.ChatKnowledge
+        elif mode == conversation_types["custome"]:
+            return ChatScene.ChatNewKnowledge
+        elif mode == conversation_types["url"]:
+            return ChatScene.ChatUrlKnowledge
+        else:
+            return ChatScene.ChatNormal
+    else:
+        sql_mode = param
+        if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
+            return ChatScene.ChatWithDbExecute
+        else:
+            return ChatScene.ChatWithDbQA


 def chatbot_callback(state, message):
     print(f"chatbot_callback:{message}")
     state.messages[-1][-1] = f"{message}"
     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

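The new single-entry dispatcher reads as "the selected tab title picks the scene family, the second argument picks the member". A self-contained sketch of the same routing; the enum values and tab titles are assumptions, not taken from this commit:

from enum import Enum

class ChatScene(Enum):
    ChatWithDbExecute = "chat_with_db_execute"
    ChatWithDbQA = "chat_with_db_qa"
    ChatExecution = "chat_execution"
    ChatKnowledge = "chat_knowledge"
    ChatNewKnowledge = "chat_new_knowledge"
    ChatUrlKnowledge = "chat_url_knowledge"
    ChatNormal = "chat_normal"

# Routing equivalent to the new get_chat_mode: spelling of the keys
# ("default_knownledge", "custome") follows the source.
KNOWLEDGE_ROUTES = {
    "default_knownledge": ChatScene.ChatKnowledge,
    "custome": ChatScene.ChatNewKnowledge,
    "url": ChatScene.ChatUrlKnowledge,
}

def route(selected, param=None, titles=("Use Plugin", "Knowledge QA")):
    plugin_title, qa_title = titles  # hypothetical tab titles
    if selected == plugin_title:
        return ChatScene.ChatExecution
    if selected == qa_title:
        return KNOWLEDGE_ROUTES.get(param, ChatScene.ChatNormal)
    # anything else is the SQL tab: sql_mode decides execute vs. QA
    return (ChatScene.ChatWithDbExecute
            if param == "auto_execute_ai_response"
            else ChatScene.ChatWithDbQA)

print(route("Use Plugin"))                       # ChatScene.ChatExecution
print(route("Knowledge QA", "url"))              # ChatScene.ChatUrlKnowledge
print(route("SQL", "auto_execute_ai_response"))  # ChatScene.ChatWithDbExecute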
 def http_bot(
-    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+    state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name
 ):
-    logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
-    start_tstamp = time.time()
-    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
-    print(f"Current chat mode: {scene.value}")
-    model_name = CFG.LLM_MODEL
-
-    if ChatScene.ChatWithDb == scene:
-        logger.info("DB-based chat goes through the new mode!")
+    logger.info(f"User message send!{state.conv_id},{selected},{plugin_selector},{mode},{sql_mode},{db_selector},{url_input}")
+    if chat_mode_title['knowledge_qa'] == selected:
+        scene: ChatScene = get_chat_mode(selected, mode)
+    elif chat_mode_title['chat_use_plugin'] == selected:
+        scene: ChatScene = get_chat_mode(selected)
+    else:
+        scene: ChatScene = get_chat_mode(selected, sql_mode)
+    print(f"chat scene:{scene.value}")
+
+    if ChatScene.ChatWithDbExecute == scene:
         chat_param = {
             "temperature": temperature,
             "max_new_tokens": max_new_tokens,
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
             "user_input": state.last_user_input
         }
         chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+    elif ChatScene.ChatWithDbQA == scene:
+        chat_param = {
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
+            "chat_session_id": state.conv_id,
+            "db_name": db_selector,
+            "user_input": state.last_user_input,
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
-        chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-    else:
-        dbname = db_selector
-        # TODO: the request here should splice in the existing knowledge base so the
-        # model answers from it; the prompt still needs further tuning
-        if state.skip_next:
-            # This generate call is skipped due to invalid inputs
-            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-            return
-
-        if len(state.messages) == state.offset + 2:
-            query = state.messages[-2][1]
-            # The first turn needs the prompt prepended
-            if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-                # The first turn in autogpt mode needs its own dedicated prompt
-                system_prompt = auto_prompt.construct_first_prompt(
-                    fisrt_message=[query], db_schemes=gen_sqlgen_conversation(dbname)
-                )
-                logger.info("[TEST]:" + system_prompt)
-                template_name = "auto_dbgpt_one_shot"
-                new_state = conv_templates[template_name].copy()
-                new_state.append_message(role="USER", message=system_prompt)
-                # new_state.append_message(new_state.roles[0], query)
-                new_state.append_message(new_state.roles[1], None)
-            else:
-                template_name = "conv_one_shot"
-                new_state = conv_templates[template_name].copy()
-                # Add a context hint to the prompt so the model answers from existing
-                # knowledge. Should the hint go only in the first turn, or in every turn?
-                # If the user's questions vary widely, add it on every turn.
-                if db_selector:
-                    new_state.append_message(
-                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
-                    )
-                    new_state.append_message(new_state.roles[1], None)
-                else:
-                    new_state.append_message(new_state.roles[0], query)
-                    new_state.append_message(new_state.roles[1], None)
-
-            new_state.conv_id = uuid.uuid4().hex
-            state = new_state
-        else:
-            ### Follow-up turns
-            query = state.messages[-2][1]
-            # The first turn needs the prompt prepended
-            if mode == conversation_types["custome"]:
-                template_name = "conv_one_shot"
-                new_state = conv_templates[template_name].copy()
-                # Add a context hint to the prompt so the model answers from existing
-                # knowledge. Should the hint go only in the first turn, or in every turn?
-                # If the user's questions vary widely, add it on every turn.
-                if db_selector:
-                    new_state.append_message(
-                        new_state.roles[0], gen_sqlgen_conversation(dbname) + query
-                    )
-                    new_state.append_message(new_state.roles[1], None)
-                else:
-                    new_state.append_message(new_state.roles[0], query)
-                    new_state.append_message(new_state.roles[1], None)
-                state = new_state
-            elif sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-                ## Fetch the last plugin return value
-                follow_up_prompt = auto_prompt.construct_follow_up_prompt([query])
-                state.messages[0][0] = ""
-                state.messages[0][1] = ""
-                state.messages[-2][1] = follow_up_prompt
-
-        prompt = state.get_prompt()
-        skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-        if mode == conversation_types["default_knownledge"] and not db_selector:
-            vector_store_config = {
-                "vector_store_name": "default",
-                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-            }
-            knowledge_embedding_client = KnowledgeEmbedding(
-                file_path="",
-                model_name=LLM_MODEL_CONFIG["text2vec"],
-                local_persist=False,
-                vector_store_config=vector_store_config,
-            )
-            query = state.messages[-2][1]
-            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-            state.messages[-2][1] = query
-            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-
-        if mode == conversation_types["custome"] and not db_selector:
-            print("vector store name: ", vector_store_name["vs_name"])
-            vector_store_config = {
-                "vector_store_name": vector_store_name["vs_name"],
-                "text_field": "content",
-                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
-            }
-            knowledge_embedding_client = KnowledgeEmbedding(
-                file_path="",
-                model_name=LLM_MODEL_CONFIG["text2vec"],
-                local_persist=False,
-                vector_store_config=vector_store_config,
-            )
-            query = state.messages[-2][1]
-            docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-            prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
-
-            state.messages[-2][1] = query
-            skip_echo_len = len(prompt.replace("</s>", " ")) + 1
-
-        # Make requests
-        payload = {
-            "model": model_name,
-            "prompt": prompt,
-            "temperature": float(temperature),
-            "max_new_tokens": int(max_new_tokens),
-            "stop": state.sep
-            if state.sep_style == SeparatorStyle.SINGLE
-            else state.sep2,
-        }
-        logger.info(f"Request: \n{payload}")
+    elif ChatScene.ChatExecution == scene:
+        chat_param = {
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
+            "chat_session_id": state.conv_id,
+            "plugin_selector": plugin_selector,
+            "user_input": state.last_user_input,
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+    elif ChatScene.ChatNormal == scene:
+        chat_param = {
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
+            "chat_session_id": state.conv_id,
+            "user_input": state.last_user_input,
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+    elif ChatScene.ChatKnowledge == scene:
+        chat_param = {
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
+            "chat_session_id": state.conv_id,
+            "user_input": state.last_user_input,
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+    elif ChatScene.ChatNewKnowledge == scene:
+        chat_param = {
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
+            "chat_session_id": state.conv_id,
+            "user_input": state.last_user_input,
+            "knowledge_name": knowledge_name
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+    elif ChatScene.ChatUrlKnowledge == scene:
+        chat_param = {
+            "temperature": temperature,
+            "max_new_tokens": max_new_tokens,
+            "chat_session_id": state.conv_id,
+            "user_input": state.last_user_input,
+            "url": url_input
+        }
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
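Each branch above ends in CHAT_FACTORY.get_implementation(scene.value, **chat_param). The factory itself is not part of this diff; a plausible minimal registry illustrating the pattern, offered only as a sketch (the registration mechanics are assumptions, not pilot's actual code):

class BaseChat:
    def __init__(self, **kwargs):
        self.params = kwargs

class ChatFactory:
    """Maps a scene value to a BaseChat subclass; a sketch, not pilot's code."""
    def __init__(self):
        self._impls = {}

    def register(self, scene_value):
        # Decorator that records cls under scene_value.
        def deco(cls):
            self._impls[scene_value] = cls
            return cls
        return deco

    def get_implementation(self, scene_value, **chat_param) -> BaseChat:
        try:
            return self._impls[scene_value](**chat_param)
        except KeyError:
            raise ValueError(f"no chat implementation for scene {scene_value!r}")

CHAT_FACTORY = ChatFactory()

@CHAT_FACTORY.register("chat_normal")  # hypothetical scene value
class ChatNormal(BaseChat):
    pass

chat = CHAT_FACTORY.get_implementation("chat_normal", chat_session_id="abc", user_input="hi")
print(type(chat).__name__, chat.params["user_input"])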
 
-    if sql_mode == conversation_sql_mode["auto_execute_ai_response"]:
-        response = requests.post(
-            urljoin(CFG.MODEL_SERVER, "generate"),
-            headers=headers,
-            json=payload,
-            timeout=120,
-        )
-
-        print(response.json())
-        print(str(response))
-        try:
-            text = response.text.strip()
-            text = text.rstrip()
-            respObj = json.loads(text)
-
-            xx = respObj["response"]
-            xx = xx.strip(b"\x00".decode())
-            respObj_ex = json.loads(xx)
-            if respObj_ex["error_code"] == 0:
-                ai_response = None
-                all_text = respObj_ex["text"]
-                ### Parse the returned text and extract the AI reply
-                tmpResp = all_text.split(state.sep)
-                last_index = -1
-                for i in range(len(tmpResp)):
-                    if tmpResp[i].find("ASSISTANT:") != -1:
-                        last_index = i
-                ai_response = tmpResp[last_index]
-                ai_response = ai_response.replace("ASSISTANT:", "")
-                ai_response = ai_response.replace("\n", "")
-                ai_response = ai_response.replace("\_", "_")
-
-                print(ai_response)
-                if ai_response == None:
-                    state.messages[-1][-1] = "ASSISTANT failed to reply correctly; the reply was:\n" + all_text
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-                else:
-                    plugin_resp = execute_ai_response_json(
-                        auto_prompt.prompt_generator, ai_response
-                    )
-                    cfg.set_last_plugin_return(plugin_resp)
-                    print(plugin_resp)
-                    state.messages[-1][-1] = (
-                        "Model inference info:\n"
-                        + ai_response
-                        + "\n\nDB-GPT execution result:\n"
-                        + plugin_resp
-                    )
-                    yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-        except NotCommands as e:
-            print("Command execution: " + e.message)
-            state.messages[-1][-1] = (
-                "Command execution: " + e.message + "\nModel output:\n" + str(ai_response)
-            )
-            yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-    else:
-        # Streaming output
-        state.messages[-1][-1] = "▌"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-
-        try:
-            # Stream output
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=20,
-            )
-            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                if chunk:
-                    data = json.loads(chunk.decode())
-
-                    """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
-                    """
-                    if data["error_code"] == 0:
-                        print("****************:", data)
-                        if "vicuna" in CFG.LLM_MODEL:
-                            output = data["text"][skip_echo_len:].strip()
-                        else:
-                            output = data["text"].strip()
-
-                        output = post_process_code(output)
-                        state.messages[-1][-1] = output + "▌"
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                        ) * 5
-                    else:
-                        output = (
-                            data["text"] + f" (error_code: {data['error_code']})"
-                        )
-                        state.messages[-1][-1] = output
-                        yield (state, state.to_gradio_chatbot()) + (
-                            disable_btn,
-                            disable_btn,
-                            disable_btn,
-                            enable_btn,
-                            enable_btn,
-                        )
-                        return
-
-        except requests.exceptions.RequestException as e:
-            state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-            yield (state, state.to_gradio_chatbot()) + (
-                disable_btn,
-                disable_btn,
-                disable_btn,
-                enable_btn,
-                enable_btn,
-            )
-            return
-
-        state.messages[-1][-1] = state.messages[-1][-1][:-1]
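The deleted streaming client above parsed the model server's output as JSON records separated by NUL bytes (delimiter=b"\0"). A minimal consumer of that wire format, using a canned byte string instead of requests; the endpoint's exact behavior beyond what the diff shows is an assumption:

import json

# Fake two server chunks in the NUL-delimited format the old code consumed.
raw = (json.dumps({"error_code": 0, "text": "SELECT 1"}).encode() + b"\0"
       + json.dumps({"error_code": 0, "text": "SELECT 1;"}).encode() + b"\0")

def iter_chunks(stream: bytes):
    # Equivalent to response.iter_lines(decode_unicode=False, delimiter=b"\0")
    for chunk in stream.split(b"\0"):
        if chunk:
            yield json.loads(chunk.decode())

for data in iter_chunks(raw):
    if data["error_code"] == 0:
        print(data["text"])  # the UI appended "▌" while a chunk was in flight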
+    if not chat.prompt_template.stream_out:
+        logger.info("not stream out, wait model response!")
+        state.messages[-1][-1] = chat.nostream_call()
+        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+    else:
+        logger.info("stream out start!")
+        try:
+            stream_gen = chat.stream_call()
+            for msg in stream_gen:
+                state.messages[-1][-1] = msg
+                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        except Exception as e:
+            print(traceback.format_exc())
+            state.messages[-1][-1] = "Error:" + str(e)
+            yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
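In the replacement flow, streaming becomes the chat object's concern: prompt_template.stream_out selects between one blocking nostream_call() and a stream_call() generator. A hedged stub of just that control flow; every name except those two attributes is invented here:

class PromptTemplate:  # stub carrying only the flag the loop branches on
    def __init__(self, stream_out: bool):
        self.stream_out = stream_out

class EchoChat:  # stand-in for a BaseChat implementation
    def __init__(self, stream_out: bool, text: str):
        self.prompt_template = PromptTemplate(stream_out)
        self._text = text

    def nostream_call(self) -> str:
        return self._text

    def stream_call(self):
        for i in range(1, len(self._text) + 1):
            yield self._text[:i]  # progressively longer prefixes, like the UI loop

for chat in (EchoChat(False, "hello"), EchoChat(True, "hello")):
    if not chat.prompt_template.stream_out:
        print(chat.nostream_call())
    else:
        for msg in chat.stream_call():
            print(msg)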
-
-    # Log this run
-    finish_tstamp = time.time()
-    logger.info(f"{output}")
-
-    with open(get_conv_log_filename(), "a") as fout:
-        data = {
-            "tstamp": round(finish_tstamp, 4),
-            "type": "chat",
-            "model": model_name,
-            "start": round(start_tstamp, 4),
-            "finish": round(finish_tstamp, 4),
-            "state": state.dict(),
-            "ip": request.client.host,
-        }
-        fout.write(json.dumps(data) + "\n")
-
-    if state.messages[-1][-1].endswith("▌"):
-        state.messages[-1][-1] = state.messages[-1][-1][:-1]
-        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
 

 block_css = (
     code_highlight_css
@@ -515,15 +362,12 @@ def change_sql_mode(sql_mode):
 

 def change_mode(mode):
     if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
         return gr.update(visible=False)
     else:
-        return gr.update(visible=True)
+        if mode in [add_knowledge_base_dialogue]:
+            return gr.update(visible=True)
+        else:
+            return gr.update(visible=False)


-def change_tab():
-    autogpt = True


 def build_single_model_ui():
     notice_markdown = get_lang_text("db_gpt_introduction")
@@ -552,7 +396,16 @@ def build_single_model_ui():
         interactive=True,
         label=get_lang_text("max_input_token_size"),
     )
+
+    tabs = gr.Tabs()
+
+    def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
+        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+        return evt.value
+
+    selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+    tabs.select(on_select, None, selected)
+
+    with tabs:
         tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
         with tab_sql:
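The hidden-Textbox trick above is how the selected tab title reaches later callbacks such as http_bot. A free-standing Gradio sketch of the same wiring; the tab titles are placeholders, and it assumes a gradio version with SelectData events, as this commit uses:

import gradio as gr

with gr.Blocks() as demo:
    tabs = gr.Tabs()
    selected = gr.Textbox(visible=False)  # hidden carrier for the tab title

    def on_select(evt: gr.SelectData):
        return evt.value  # the title of the tab the user clicked

    tabs.select(on_select, None, selected)

    with tabs:
        with gr.TabItem("SQL"):
            gr.Markdown("SQL tab")
        with gr.TabItem("Knowledge QA"):
            gr.Markdown("QA tab")

# demo.launch()  # uncomment to run locally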
@@ -572,11 +425,34 @@ def build_single_model_ui():
                 get_lang_text("sql_generate_mode_none"),
             ],
             show_label=False,
-            value=get_lang_text("sql_generate_mode_none"),
+            value=get_lang_text("sql_generate_mode_none")
         )
         sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
         sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)

+        tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
+        # tab_plugin.select(change_func)
+        with tab_plugin:
+            print("tab_plugin in...")
+            with gr.Row(elem_id="plugin_selector"):
+                # TODO
+                plugin_selector = gr.Dropdown(
+                    label=get_lang_text("select_plugin"),
+                    choices=list(plugins_select_info().keys()),
+                    value="",
+                    interactive=True,
+                    show_label=True,
+                    type="value"
+                ).style(container=False)
+
+                def plugin_change(evt: gr.SelectData):  # SelectData is a subclass of EventData
+                    print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+                    print(f"user plugin:{plugins_select_info().get(evt.value)}")
+                    return plugins_select_info().get(evt.value)
+
+                plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
+                plugin_selector.select(plugin_change, None, plugin_selected)
+
         tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
         with tab_qa:
             mode = gr.Radio(
@@ -584,14 +460,25 @@ def build_single_model_ui():
                     llm_native_dialogue,
                     default_knowledge_base_dialogue,
                     add_knowledge_base_dialogue,
+                    url_knowledge_dialogue,
                 ],
                 show_label=False,
                 value=llm_native_dialogue,
             )
             vs_setting = gr.Accordion(
-                get_lang_text("configure_knowledge_base"), open=False
+                get_lang_text("configure_knowledge_base"), open=False, visible=False
             )
             mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
+
+            url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True, visible=False)
+
+            def show_url_input(evt: gr.SelectData):
+                if evt.value == url_knowledge_dialogue:
+                    return gr.update(visible=True)
+                else:
+                    return gr.update(visible=False)
+
+            mode.select(fn=show_url_input, inputs=None, outputs=url_input)
+
             with vs_setting:
                 vs_name = gr.Textbox(
                     label=get_lang_text("new_klg_name"), lines=1, interactive=True
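show_url_input toggles the URL box by returning gr.update(visible=...); mode.change drives the vs_setting accordion the same way. A condensed sketch of that mechanism (the option labels are placeholders):

import gradio as gr

URL_OPTION = "URL knowledge dialogue"  # placeholder for url_knowledge_dialogue

with gr.Blocks() as demo:
    mode = gr.Radio(choices=["LLM native", URL_OPTION], value="LLM native")
    url_input = gr.Textbox(label="URL", visible=False)

    def show_url_input(evt: gr.SelectData):
        # gr.update mutates only the named properties of the output component.
        return gr.update(visible=(evt.value == URL_OPTION))

    mode.select(show_url_input, None, url_input)

# demo.launch()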
@@ -639,10 +526,14 @@ def build_single_model_ui():
     clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)

     gr.Markdown(learn_more_markdown)

+    params = [plugin_selected, mode, sql_mode, db_selector, url_input, vs_name]
+
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
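Note that the new inputs list lines up positionally with the rewritten http_bot signature: [state, selected, temperature, max_output_tokens] + params expands to (state, selected, temperature, max_new_tokens, plugin_selector, mode, sql_mode, db_selector, url_input, knowledge_name), so plugin_selected feeds plugin_selector and vs_name feeds knowledge_name. Gradio passes component values to a callback strictly by position, as this toy callback shows:

import gradio as gr

def echo(state, selected, temperature):
    # Values arrive in the order of the inputs list, not by component name.
    return f"{selected} @ temperature={temperature}"

with gr.Blocks() as demo:
    state = gr.State({})
    selected = gr.Textbox(value="Knowledge QA")
    temperature = gr.Slider(0, 1, value=0.7)
    out = gr.Textbox()
    btn = gr.Button("Send")
    btn.click(echo, [state, selected, temperature], out)

# demo.launch()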
@@ -651,7 +542,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )

@@ -659,7 +550,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, selected, temperature, max_output_tokens] + params,
         [state, chatbot] + btn_list,
     )
     vs_add.click(
@@ -766,8 +657,8 @@ if __name__ == "__main__":
 
     # Load the plugins' executable commands
     command_categories = [
-        "pilot.commands.audio_text",
-        "pilot.commands.image_gen",
+        "pilot.commands.built_in.audio_text",
+        "pilot.commands.built_in.image_gen",
     ]
     # Exclude disabled commands
     command_categories = [