knownledge base answer fix

This commit is contained in:
csunny
2023-05-07 17:58:24 +08:00
parent 56e9cde86e
commit 55b3702aeb
3 changed files with 21 additions and 15 deletions

View File

@@ -11,6 +11,7 @@ import datetime
import requests
from urllib.parse import urljoin
from pilot.configs.model_config import DB_SETTINGS
from pilot.server.vectordb_qa import KnownLedgeBaseQA
from pilot.connections.mysql_conn import MySQLOperator
from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
@@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
from pilot.conversation import (
default_conversation,
conv_templates,
conversation_types,
SeparatorStyle
)
@@ -149,7 +151,7 @@ def post_process_code(code):
code = sep.join(blocks)
return code
def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
start_tstamp = time.time()
model_name = LLM_MODEL
@@ -180,14 +182,12 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
new_state.append_message(new_state.roles[0], query)
new_state.append_message(new_state.roles[1], None)
state = new_state
# try:
# if not db_selector:
# sim_q = get_simlar(query)
# print("********vector similar info*************: ", sim_q)
# state.append_message(new_state.roles[0], sim_q + query)
# except Exception as e:
# print(e)
if mode == conversation_types["default_knownledge"] and not db_selector:
query = state.messages[-2][1]
knqa = KnownLedgeBaseQA()
state.messages[-2][1] = knqa.get_similar_answer(query)
prompt = state.get_prompt()
@@ -268,7 +268,7 @@ def change_tab(tab):
pass
def change_mode(mode):
if mode == "默认知识库对话":
if mode in ["默认知识库对话", "LLM原生对话"]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
@@ -320,7 +320,8 @@ def build_single_model_ui():
show_label=True).style(container=False)
with gr.TabItem("知识问答", elem_id="QA"):
mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting:
@@ -365,7 +366,7 @@ def build_single_model_ui():
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
http_bot,
[state, db_selector, temperature, max_output_tokens],
[state, mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -374,7 +375,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, db_selector, temperature, max_output_tokens],
[state, mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list,
)
@@ -382,7 +383,7 @@ def build_single_model_ui():
add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(
http_bot,
[state, db_selector, temperature, max_output_tokens],
[state, mode, db_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list
)