From 55b3702aeb112ca8f27465d41a6f706c64238a2b Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 7 May 2023 17:58:24 +0800
Subject: [PATCH] knownledge base answer fix

---
 pilot/conversation.py       |  6 ++++++
 pilot/server/vectordb_qa.py |  1 -
 pilot/server/webserver.py   | 29 +++++++++++++++--------------
 3 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/pilot/conversation.py b/pilot/conversation.py
index 88d5ca591..2dc8df2b9 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -158,6 +158,12 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回
 
 default_conversation = conv_one_shot
 
+conversation_types = {
+    "native": "LLM原生对话",
+    "default_knownledge": "默认知识库对话",
+    "custome": "新增知识库对话",
+}
+
 conv_templates = {
     "conv_one_shot": conv_one_shot,
     "vicuna_v1": conv_vicuna_v1,
diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
index f0b33f346..71a9b881d 100644
--- a/pilot/server/vectordb_qa.py
+++ b/pilot/server/vectordb_qa.py
@@ -4,7 +4,6 @@
 from pilot.vector_store.file_loader import KnownLedge2Vector
 from langchain.prompts import PromptTemplate
 from pilot.conversation import conv_qa_prompt_template
-from langchain.chains import RetrievalQA
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.model.vicuna_llm import VicunaLLM
 
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 9d77aa50b..b31bfde7f 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -11,6 +11,7 @@ import datetime
 import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.connections.mysql_conn import MySQLOperator
 from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
@@ -19,6 +20,7 @@ from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, D
 from pilot.conversation import (
     default_conversation,
     conv_templates,
+    conversation_types,
     SeparatorStyle
 )
 
@@ -149,7 +151,7 @@ def post_process_code(code):
         code = sep.join(blocks)
     return code
 
-def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Request):
+def http_bot(state, mode, db_selector, temperature, max_new_tokens, request: gr.Request):
     start_tstamp = time.time()
     model_name = LLM_MODEL
 
@@ -180,14 +182,12 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state.append_message(new_state.roles[0], query)
         new_state.append_message(new_state.roles[1], None)
         state = new_state
-
-    # try:
-    #     if not db_selector:
-    #         sim_q = get_simlar(query)
-    #         print("********vector similar info*************: ", sim_q)
-    #         state.append_message(new_state.roles[0], sim_q + query)
-    # except Exception as e:
-    #     print(e)
+
+    if mode == conversation_types["default_knownledge"] and not db_selector:
+        query = state.messages[-2][1]
+        knqa = KnownLedgeBaseQA()
+        state.messages[-2][1] = knqa.get_similar_answer(query)
+
 
     prompt = state.get_prompt()
 
@@ -268,7 +268,7 @@ def change_tab(tab):
     pass
 
 def change_mode(mode):
-    if mode == "默认知识库对话":
+    if mode in ["默认知识库对话", "LLM原生对话"]:
         return gr.update(visible=False)
     else:
         return gr.update(visible=True)
@@ -320,7 +320,8 @@ def build_single_model_ui():
                     show_label=True).style(container=False)
 
             with gr.TabItem("知识问答", elem_id="QA"):
-                mode = gr.Radio(["默认知识库对话", "新增知识库"], show_label=False, value="默认知识库对话")
+
+                mode = gr.Radio(["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话")
                 vs_setting = gr.Accordion("配置知识库", open=False)
                 mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
                 with vs_setting:
@@ -365,7 +366,7 @@
     btn_list = [regenerate_btn, clear_btn]
     regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
@@ -374,7 +375,7 @@
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
        [state, chatbot] + btn_list,
     )
 
@@ -382,7 +383,7 @@
         add_text, [state, textbox], [state, chatbot, textbox] + btn_list
     ).then(
         http_bot,
-        [state, db_selector, temperature, max_output_tokens],
+        [state, mode, db_selector, temperature, max_output_tokens],
         [state, chatbot] + btn_list
     )
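
Reviewer sketch (not part of the patch): the branch added to http_bot rewrites the last user turn in place before the prompt is assembled. The snippet below restates that flow with the patch's own names; the helper name inject_knowledge is illustrative only, and it assumes KnownLedgeBaseQA.get_similar_answer(query) returns the knowledge-grounded prompt string built in pilot/server/vectordb_qa.py, whose internals are not shown in this diff.

    # Sketch only: mirrors the branch the patch adds to http_bot.
    from pilot.conversation import conversation_types
    from pilot.server.vectordb_qa import KnownLedgeBaseQA

    def inject_knowledge(state, mode, db_selector):
        # Only the default knowledge-base mode with no database selected is rewritten.
        if mode == conversation_types["default_knownledge"] and not db_selector:
            query = state.messages[-2][1]  # last user message in the conversation state
            knqa = KnownLedgeBaseQA()
            # Replace the raw query with a prompt grounded in retrieved documents.
            state.messages[-2][1] = knqa.get_similar_answer(query)
        return state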