Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-16 14:40:56 +00:00)
add plugin mode
@@ -30,7 +30,7 @@ from pilot.configs.model_config import (
     LOGDIR,
     VECTOR_SEARCH_TOP_K,
 )
-from pilot.connections.mysql import MySQLOperator
+
 from pilot.conversation import (
     SeparatorStyle,
     conv_qa_prompt_template,
@@ -39,9 +39,9 @@ from pilot.conversation import (
     conversation_types,
     default_conversation,
 )
-from pilot.plugins import scan_plugins
-from pilot.prompts.auto_mode_prompt import AutoModePrompt
-from pilot.prompts.generator import PromptGenerator
+from pilot.common.plugins import scan_plugins
+
+from pilot.prompts.generator import PluginPromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
 from pilot.server.vectordb_qa import KnownLedgeBaseQA
@@ -95,19 +95,14 @@ def get_simlar(q):


 def gen_sqlgen_conversation(dbname):
-    mo = MySQLOperator(**DB_SETTINGS)

     message = ""

-    schemas = mo.get_schema(dbname)
+    db_connect = CFG.local_db.get_session(dbname)
+    schemas = CFG.local_db.table_simple_info(db_connect)
     for s in schemas:
         message += s["schema_info"] + ";"
     return f"数据库{dbname}的Schema信息如下: {message}\n"
-
-
-def get_database_list():
-    mo = MySQLOperator(**DB_SETTINGS)
-    return mo.get_db_list()


 get_window_url_params = """
@@ -127,7 +122,6 @@ function() {
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

-    # dbs = get_database_list()
     dropdown_update = gr.Dropdown.update(visible=True)
     if dbs:
         gr.Dropdown.update(choices=dbs)
@@ -137,13 +131,15 @@ def load_demo(url_params, request: gr.Request):
     unique_id = uuid.uuid1()
     state.conv_id = str(unique_id)

-    return (state,
-            dropdown_update,
-            gr.Chatbot.update(visible=True),
-            gr.Textbox.update(visible=True),
-            gr.Button.update(visible=True),
-            gr.Row.update(visible=True),
-            gr.Accordion.update(visible=True))
+    return (
+        state,
+        dropdown_update,
+        gr.Chatbot.update(visible=True),
+        gr.Textbox.update(visible=True),
+        gr.Button.update(visible=True),
+        gr.Row.update(visible=True),
+        gr.Accordion.update(visible=True),
+    )


 def get_conv_log_filename():
@@ -203,30 +199,31 @@ def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
     elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
         return ChatScene.ChatExecution
-    else:
-        return ChatScene.ChatNormal
+    return ChatScene.ChatNormal


-def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(
+    state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
+):
     logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
     start_tstamp = time.time()
-    scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
+    scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
     print(f"当前对话模式:{scene.value}")
     model_name = CFG.LLM_MODEL

     if ChatScene.ChatWithDb == scene:
         logger.info("基于DB对话走新的模式!")
-        chat_param ={
+        chat_param = {
             "chat_session_id": state.conv_id,
             "db_name": db_selector,
-            "user_input": state.last_user_input
+            "user_input": state.last_user_input,
         }
-        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
+        chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
         chat.call()
-        state.messages[-1][-1] = f"{chat.current_ai_response()}"
+        state.messages[-1][-1] = f"{chat.current_ai_response()}"
         yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

     else:

         dbname = db_selector
         # TODO: this request should be combined with the existing knowledge base so answers are grounded in it; the prompt needs further tuning
         if state.skip_next:
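The ChatWithDb branch above hands the request to CHAT_FACTORY.get_implementation, which resolves a BaseChat implementation from the scene value. A minimal sketch of such a scene-keyed factory, assuming a plain registry; the class and method names mirror the call site, but the body is illustrative, not DB-GPT's actual code:

    # Hypothetical scene-keyed chat factory; illustrative only.
    class ChatFactory:
        def __init__(self):
            self._registry = {}  # maps a scene value to a BaseChat subclass

        def register(self, scene_value, chat_cls):
            self._registry[scene_value] = chat_cls

        def get_implementation(self, scene_value, **chat_param):
            # chat_param carries chat_session_id, db_name, user_input as in the diff
            chat_cls = self._registry.get(scene_value)
            if chat_cls is None:
                raise ValueError(f"no chat implementation for scene {scene_value!r}")
            return chat_cls(**chat_param)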
@@ -242,7 +239,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
     # Add a context hint to the prompt so the model answers from existing knowledge. Should the hint go only on the first turn, or on every turn?
     # If the user's questions jump between topics, the hint should be added on every turn.
     if db_selector:
-        new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+        new_state.append_message(
+            new_state.roles[0], gen_sqlgen_conversation(dbname) + query
+        )
         new_state.append_message(new_state.roles[1], None)
     else:
         new_state.append_message(new_state.roles[0], query)
@@ -251,7 +250,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         new_state.conv_id = uuid.uuid4().hex
         state = new_state
-

     prompt = state.get_prompt()
     skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
@@ -263,16 +261,24 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1

     if mode == conversation_types["custome"] and not db_selector:
-        persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
+        persist_dir = os.path.join(
+            KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
+        )
         print("向量数据库持久化地址: ", persist_dir)
-        knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"],
-                                                        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["sentence-transforms"],
+            vector_store_config={
+                "vector_store_name": vector_store_name["vs_name"],
+                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+            },
+        )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, 1)
         context = [d.page_content for d in docs]
         prompt_template = PromptTemplate(
             template=conv_qa_prompt_template,
-            input_variables=["context", "question"]
+            input_variables=["context", "question"],
         )
         result = prompt_template.format(context="\n".join(context), question=query)
         state.messages[-2][1] = result
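The custom knowledge-base branch is a retrieval-augmented prompt build: search the vector store for the document closest to the user query, then splice it into the QA template before the request is sent. A self-contained sketch of the same flow; the template text and the similar_search stand-in are assumptions, while PromptTemplate is the LangChain class used above:

    from langchain.prompts import PromptTemplate

    # Stand-in template; the real one is conv_qa_prompt_template from pilot.conversation.
    QA_TEMPLATE = "Answer based on the context.\nContext: {context}\nQuestion: {question}"

    def build_qa_prompt(similar_search, query, top_k=1):
        # similar_search stands in for knowledge_embedding_client.similar_search
        docs = similar_search(query, top_k)
        context = [d.page_content for d in docs]
        prompt_template = PromptTemplate(
            template=QA_TEMPLATE,
            input_variables=["context", "question"],
        )
        return prompt_template.format(context="\n".join(context), question=query)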
@@ -285,7 +291,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
         "prompt": prompt,
         "temperature": float(temperature),
         "max_new_tokens": int(max_new_tokens),
-        "stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
+        "stop": state.sep
+        if state.sep_style == SeparatorStyle.SINGLE
+        else state.sep2,
     }
     logger.info(f"Requert: \n{payload}")

@@ -295,8 +303,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re

     try:
         # Stream output
-        response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers, json=payload, stream=True, timeout=20)
+        response = requests.post(
+            urljoin(CFG.MODEL_SERVER, "generate_stream"),
+            headers=headers,
+            json=payload,
+            stream=True,
+            timeout=20,
+        )
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
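The reformatted requests.post call above streams NUL-delimited JSON chunks from the model server. A standalone sketch of a client consuming that protocol; the URL construction and response keys follow the diff, everything else is an assumption:

    import json
    from urllib.parse import urljoin

    import requests

    def stream_generate(model_server, payload, timeout=20):
        # The server is assumed to emit b"\0"-separated JSON objects.
        response = requests.post(
            urljoin(model_server, "generate_stream"),
            json=payload,
            stream=True,
            timeout=timeout,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                yield data  # callers inspect data["text"] and data["error_code"]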
@@ -309,12 +322,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
                 output = data["text"] + f" (error_code: {data['error_code']})"
                 state.messages[-1][-1] = output
                 yield (state, state.to_gradio_chatbot()) + (
-                    disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+                    disable_btn,
+                    disable_btn,
+                    disable_btn,
+                    enable_btn,
+                    enable_btn,
+                )
                 return

     except requests.exceptions.RequestException as e:
         state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
-        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+        yield (state, state.to_gradio_chatbot()) + (
+            disable_btn,
+            disable_btn,
+            disable_btn,
+            enable_btn,
+            enable_btn,
+        )
         return

     state.messages[-1][-1] = state.messages[-1][-1][:-1]
@@ -405,10 +429,14 @@ def build_single_model_ui():
                 interactive=True,
                 label="最大输出Token数",
             )

+
+        tabs = gr.Tabs()
+
+        with tabs:
             tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
             with tab_sql:
                 print("tab_sql in...")
                 # TODO A selector to choose database
                 with gr.Row(elem_id="db_selector"):
                     db_selector = gr.Dropdown(
@@ -423,8 +451,23 @@ def build_single_model_ui():
                 sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
                 sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
+
+            tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
+            with tab_plugin:
+                print("tab_plugin in...")
+                with gr.Row(elem_id="plugin_selector"):
+                    # TODO
+                    plugin_selector = gr.Dropdown(
+                        label="请选择插件",
+                        choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""],
+                        value="datadance-ddl-excutor",
+                        interactive=True,
+                        show_label=True,
+                    ).style(container=False)
+
+
             tab_qa = gr.TabItem("知识问答", elem_id="QA")
             with tab_qa:
                 print("tab_qa in...")
                 mode = gr.Radio(
                     ["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
                 )
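The new plugin tab is the user-facing half of the plugin mode; the other half is the scan_plugins import moved to pilot.common.plugins earlier in the diff. The scanner itself is not shown in this commit; a minimal sketch of one plausible directory-based approach, hypothetical rather than DB-GPT's actual loader:

    import importlib.util
    from pathlib import Path

    def scan_plugins(plugins_dir="plugins"):
        # Hypothetical discovery: import each top-level module in plugins_dir
        # and keep the ones exposing a `can_handle_post_prompt`-style hook.
        plugins = []
        for path in Path(plugins_dir).glob("*.py"):
            spec = importlib.util.spec_from_file_location(path.stem, path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if hasattr(module, "can_handle_post_prompt"):
                plugins.append(module)
        return plugins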
@@ -483,7 +526,7 @@ def build_single_model_ui():
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
+        [state, mode, sql_mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )

@@ -573,7 +616,6 @@ def knowledge_embedding_store(vs_id, files):
     )
     knowledge_embedding_client.knowledge_embedding()

-
     logger.info("knowledge embedding success")
     return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")
