add plugin mode

This commit is contained in:
yhjun1026
2023-05-30 17:20:37 +08:00
parent dd5fc529e2
commit 5150cfcf55
14 changed files with 212 additions and 153 deletions

View File

@@ -37,6 +37,7 @@ from pilot.conversation import (
conv_templates,
conversation_sql_mode,
conversation_types,
chat_mode_title,
default_conversation,
)
from pilot.common.plugins import scan_plugins
@@ -95,6 +96,11 @@ default_knowledge_base_dialogue = get_lang_text(
add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
url_knowledge_dialogue = get_lang_text(
"knowledge_qa_type_url_knowledge_dialogue"
)
knowledge_qa_type_list = [
llm_native_dialogue,
default_knowledge_base_dialogue,
@@ -115,7 +121,7 @@ def gen_sqlgen_conversation(dbname):
db_connect = CFG.local_db.get_session(dbname)
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s["schema_info"] + ";"
message += s+ ";"
return get_lang_text("sql_schema_info").format(dbname, message)
@@ -211,9 +217,9 @@ def post_process_code(code):
def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
if "插件模式" == selected:
if chat_mode_title['chat_use_plugin'] == selected:
return ChatScene.ChatExecution
elif "知识问答" == selected:
elif chat_mode_title['knowledge_qa'] == selected:
if mode == conversation_types["default_knownledge"]:
return ChatScene.ChatKnowledge
elif mode == conversation_types["custome"]:
@@ -226,37 +232,50 @@ def get_chat_mode(selected, mode, sql_mode, db_selector) -> ChatScene:
def http_bot(
state, selected, plugin_selector, mode, sql_mode, db_selector, temperature, max_new_tokens, request: gr.Request
state, selected, plugin_selector, mode, sql_mode, db_selector, url_input, temperature, max_new_tokens, request: gr.Request
):
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
start_tstamp = time.time()
scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
scene:ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
print(f"now chat scene:{scene.value}")
model_name = CFG.LLM_MODEL
def chatbot_callback(state, message):
print(f"chatbot_callback:{message}")
state.messages[-1][-1] = f"{message}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
if ChatScene.ChatWithDb == scene:
logger.info("基于DB对话走新的模式")
logger.info("chat with db mode use new architecture design")
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.call(show_fn=chatbot_callback, state= state)
elif ChatScene.ChatExecution == scene:
logger.info("插件模式对话走新的模式")
logger.info("plugin mode use new architecture design")
chat_param = {
"chat_session_id": state.conv_id,
"plugin_selector": plugin_selector,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
state.messages[-1][-1] = f"{chat.current_ai_response()}"
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.call(chatbot_callback, state)
# def generate_numbers():
# for i in range(10):
# time.sleep(0.5)
# yield f"Message:{i}"
#
# def showMessage(message):
# return message
#
# for n in generate_numbers():
# state.messages[-1][-1] = n + "▌"
# yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
@@ -284,30 +303,45 @@ def http_bot(
new_state.conv_id = uuid.uuid4().hex
state = new_state
else:
### 后续对话
query = state.messages[-2][1]
# 第一轮对话需要加入提示Prompt
if mode == conversation_types["custome"]:
template_name = "conv_one_shot"
new_state = conv_templates[template_name].copy()
# prompt 中添加上下文提示, 根据已有知识对话, 上下文提示是否也应该放在第一轮, 还是每一轮都添加上下文?
# 如果用户侧的问题跨度很大, 应该每一轮都加提示。
if db_selector:
new_state.append_message(
new_state.roles[0], gen_sqlgen_conversation(dbname) + query
)
new_state.append_message(new_state.roles[1], None)
else:
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
prompt = state.get_prompt()
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["default_knownledge"] and not db_selector:
vector_store_config = {
"vector_store_name": "default",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt()
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["custome"] and not db_selector:
persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb"
)
print("向量数据库持久化地址: ", persist_dir)
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["sentence-transforms"],
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = {
"vector_store_name": vector_store_name["vs_name"],
@@ -327,6 +361,27 @@ def http_bot(
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
if mode == conversation_types["url"] and url_input:
print("url: ", url_input)
vector_store_config = {
"vector_store_name": url_input,
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path=url_input,
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
payload = {
"model": model_name,
@@ -355,13 +410,24 @@ def http_bot(
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
output = data["text"][skip_echo_len:].strip()
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
@@ -371,56 +437,7 @@ def http_bot(
enable_btn,
)
return
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + ""
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
@@ -432,29 +449,29 @@ def http_bot(
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# 记录运行日志
finish_tstamp = time.time()
logger.info(f"{output}")
# 记录运行日志
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
block_css = (
code_highlight_css
+ """
code_highlight_css
+ """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -477,15 +494,12 @@ def change_sql_mode(sql_mode):
def change_mode(mode):
if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
return gr.update(visible=False)
else:
if mode in [add_knowledge_base_dialogue]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_tab():
autogpt = True
def build_single_model_ui():
notice_markdown = get_lang_text("db_gpt_introduction")
@@ -548,15 +562,14 @@ def build_single_model_ui():
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
tab_plugin = gr.TabItem(get_lang_text("chat_use_plugin"), elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
print("tab_plugin in...")
with gr.Row(elem_id="plugin_selector"):
# TODO
plugin_selector = gr.Dropdown(
label="请选择插件",
label=get_lang_text("select_plugin"),
choices=list(plugins_select_info().keys()),
value="",
interactive=True,
@@ -578,6 +591,7 @@ def build_single_model_ui():
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
url_knowledge_dialogue,
],
show_label=False,
value=llm_native_dialogue,
@@ -586,6 +600,16 @@ def build_single_model_ui():
get_lang_text("configure_knowledge_base"), open=False
)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
url_input = gr.Textbox(label=get_lang_text("url_input_label"), lines=1, interactive=True)
def show_url_input(evt:gr.SelectData):
if evt.value == url_knowledge_dialogue:
return gr.update(visible=True)
else:
return gr.update(visible=False)
mode.select(fn=show_url_input, inputs=None, outputs=url_input)
with vs_setting:
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
@@ -636,7 +660,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -645,7 +669,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@@ -653,7 +677,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, selected, plugin_selected, mode, sql_mode, db_selector, temperature, max_output_tokens],
[state, selected, plugin_selected, mode, sql_mode, db_selector, url_input, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
vs_add.click(
@@ -760,8 +784,8 @@ if __name__ == "__main__":
# 加载插件可执行命令
command_categories = [
"pilot.commands.audio_text",
"pilot.commands.image_gen",
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# 排除禁用命令
command_categories = [