add plugin mode

This commit is contained in:
yhjun1026
2023-05-29 19:32:20 +08:00
parent 52da74c54a
commit 20edf6daaa
45 changed files with 1202 additions and 804 deletions

View File

@@ -30,7 +30,7 @@ from pilot.configs.model_config import (
LOGDIR,
VECTOR_SEARCH_TOP_K,
)
from pilot.connections.mysql import MySQLOperator
from pilot.conversation import (
SeparatorStyle,
conv_qa_prompt_template,
@@ -39,9 +39,9 @@ from pilot.conversation import (
conversation_types,
default_conversation,
)
from pilot.plugins import scan_plugins
from pilot.prompts.auto_mode_prompt import AutoModePrompt
from pilot.prompts.generator import PromptGenerator
from pilot.common.plugins import scan_plugins
from pilot.prompts.generator import PluginPromptGenerator
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
from pilot.server.vectordb_qa import KnownLedgeBaseQA
@@ -95,19 +95,14 @@ def get_simlar(q):
def gen_sqlgen_conversation(dbname):
mo = MySQLOperator(**DB_SETTINGS)
message = ""
schemas = mo.get_schema(dbname)
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
def get_database_list():
mo = MySQLOperator(**DB_SETTINGS)
return mo.get_db_list()
get_window_url_params = """
@@ -127,7 +122,6 @@ function() {
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
# dbs = get_database_list()
dropdown_update = gr.Dropdown.update(visible=True)
if dbs:
gr.Dropdown.update(choices=dbs)
@@ -137,13 +131,15 @@ def load_demo(url_params, request: gr.Request):
unique_id = uuid.uuid1()
state.conv_id = str(unique_id)
return (state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True))
return (
state,
dropdown_update,
gr.Chatbot.update(visible=True),
gr.Textbox.update(visible=True),
gr.Button.update(visible=True),
gr.Row.update(visible=True),
gr.Accordion.update(visible=True),
)
def get_conv_log_filename():
@@ -203,30 +199,31 @@ def get_chat_mode(mode, sql_mode, db_selector) -> ChatScene:
elif mode == conversation_types["auto_execute_plugin"] and not db_selector:
return ChatScene.ChatExecution
else:
return ChatScene.ChatNormal
return ChatScene.ChatNormal
def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request):
def http_bot(
state, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
):
logger.info(f"User message send!{state.conv_id},{sql_mode},{db_selector}")
start_tstamp = time.time()
scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
scene: ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
if ChatScene.ChatWithDb == scene:
logger.info("基于DB对话走新的模式")
chat_param ={
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO 这里的请求需要拼接现有知识库, 使得其根据现有知识库作答, 所以prompt需要继续优化
if state.skip_next:
@@ -242,7 +239,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
@@ -251,7 +250,6 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
new_state.conv_id = uuid.uuid4().hex
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
@@ -263,16 +261,24 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb")
persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
)
print("向量数据库持久化地址: ", persist_dir)
knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH})
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, 1)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"]
input_variables=["context", "question"],
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
@@ -285,7 +291,9 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": int(max_new_tokens),
"stop": state.sep if state.sep_style == SeparatorStyle.SINGLE else state.sep2,
"stop": state.sep
if state.sep_style == SeparatorStyle.SINGLE
else state.sep2,
}
logger.info(f"Requert: \n{payload}")
@@ -295,8 +303,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
try:
# Stream output
response = requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers, json=payload, stream=True, timeout=20)
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
@@ -309,12 +322,23 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
@@ -405,10 +429,14 @@ def build_single_model_ui():
interactive=True,
label="最大输出Token数",
)
tabs = gr.Tabs()
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
with tab_sql:
print("tab_sql in...")
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
@@ -423,8 +451,23 @@ def build_single_model_ui():
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
with tab_plugin:
print("tab_plugin in...")
with gr.Row(elem_id="plugin_selector"):
# TODO
plugin_selector = gr.Dropdown(
label="请选择插件",
choices=[""" [datadance-ddl-excutor]->use datadance deal the ddl task """, """[file-writer]-file read and write """, """ [image-excutor]-> image build"""],
value="datadance-ddl-excutor",
interactive=True,
show_label=True,
).style(container=False)
tab_qa = gr.TabItem("知识问答", elem_id="QA")
with tab_qa:
print("tab_qa in...")
mode = gr.Radio(
["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
)
@@ -483,7 +526,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@@ -573,7 +616,6 @@ def knowledge_embedding_store(vs_id, files):
)
knowledge_embedding_client.knowledge_embedding()
logger.info("knowledge embedding success")
return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb")