From c1758f030b50605a6e889577400c6000b01b36b3 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sat, 6 May 2023 00:41:35 +0800
Subject: [PATCH 1/2] knownledge based qa

---
 pilot/app.py                        |  4 ++--
 pilot/configs/model_config.py       |  2 +-
 pilot/server/webserver.py           | 24 ++++++++++++++++--------
 pilot/vector_store/extract_tovec.py | 10 ++++------
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/pilot/app.py b/pilot/app.py
index b3f8d6ab1..1cbcd79b0 100644
--- a/pilot/app.py
+++ b/pilot/app.py
@@ -40,8 +40,8 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
-    docsearch = knownledge_tovec_st("./datasets/plan.md")
+    from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
+    docsearch = load_knownledge_from_doc()
 
     docs = docsearch.similarity_search_with_score(q, k=1)
     for doc in docs:
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index a9436276f..4527fafe0 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -8,7 +8,7 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 VECTORE_PATH = os.path.join(ROOT_PATH, "vector_store")
 LOGDIR = os.path.join(ROOT_PATH, "logs")
-DATASETS_DIR = os.path.join(ROOT_PATH, "datasets")
+DATASETS_DIR = os.path.join(ROOT_PATH, "pilot/datasets")
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 LLM_MODEL_CONFIG = {
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 8154ea99d..b7a0ef816 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -12,9 +12,9 @@ import requests
 from urllib.parse import urljoin
 from pilot.configs.model_config import DB_SETTINGS
 from pilot.connections.mysql_conn import MySQLOperator
-from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc
+from pilot.vector_store.extract_tovec import get_vector_storelist, load_knownledge_from_doc, knownledge_tovec_st
 
-from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL
+from pilot.configs.model_config import LOGDIR, VICUNA_MODEL_SERVER, LLM_MODEL, DATASETS_DIR
 
 from pilot.conversation import (
     default_conversation,
@@ -50,7 +50,7 @@ priority = {
 
 
 def get_simlar(q):
-    docsearch = load_knownledge_from_doc()
+    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
     docs = docsearch.similarity_search_with_score(q, k=1)
 
     contents = [dc.page_content for dc, _ in docs]
@@ -171,12 +171,20 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         # prompt 中添加上下文提示
         if db_selector:
             new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
+        else:
+            new_state.append_message(new_state.roles[0], query)
+            new_state.append_message(new_state.roles[1], None)
+            state = new_state
 
-        new_state.append_message(new_state.roles[1], None)
-        state = new_state
-
-        if not db_selector:
-            state.append_message(new_state.roles[0], get_simlar(query) + query)
+        try:
+            if not db_selector:
+                sim_q = get_simlar(query)
+                print("********vector similar info*************: ", sim_q)
+                state.append_message(new_state.roles[0], sim_q + query)
+        except Exception as e:
+            print(e)
 
     prompt = state.get_prompt()
 
diff --git a/pilot/vector_store/extract_tovec.py b/pilot/vector_store/extract_tovec.py
index 223ff90c8..8badf6fed 100644
--- a/pilot/vector_store/extract_tovec.py
+++ b/pilot/vector_store/extract_tovec.py
@@ -50,17 +50,15 @@ def load_knownledge_from_doc():
     from pilot.configs.model_config import LLM_MODEL_CONFIG
     embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])
 
-    docs = []
     files = os.listdir(DATASETS_DIR)
     for file in files:
         if not os.path.isdir(file):
-            with open(file, "r") as f:
-                doc = f.read()
-                docs.append(docs)
+            filename = os.path.join(DATASETS_DIR, file)
+            with open(filename, "r") as f:
+                knownledge = f.read()
 
-    print(doc)
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_owerlap=0)
-    texts = text_splitter.split_text("\n".join(docs))
+    texts = text_splitter.split_text(knownledge)
 
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))], persist_directory=os.path.join(VECTORE_PATH, ".vectore"))
     return docsearch

From 5fbd83e573a3a8253b9d4416841968556a07b8d0 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sat, 6 May 2023 01:13:47 +0800
Subject: [PATCH 2/2] fix

---
 pilot/app.py              |  6 +++---
 pilot/server/webserver.py | 18 ++++++++++--------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/pilot/app.py b/pilot/app.py
index 1cbcd79b0..7cb3aad7f 100644
--- a/pilot/app.py
+++ b/pilot/app.py
@@ -40,13 +40,13 @@ def get_answer(q):
     return response.response
 
 def get_similar(q):
-    from pilot.vector_store.extract_tovec import knownledge_tovec, load_knownledge_from_doc
-    docsearch = load_knownledge_from_doc()
+    from pilot.vector_store.extract_tovec import knownledge_tovec, knownledge_tovec_st
+    docsearch = knownledge_tovec_st("./datasets/plan.md")
 
     docs = docsearch.similarity_search_with_score(q, k=1)
     for doc in docs:
         dc, s = doc
-        print(dc.page_content)
+        print(s)
         yield dc.page_content
 
 if __name__ == "__main__":
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index b7a0ef816..5aa3c6526 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -160,7 +160,7 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
         return
 
-    query = state.messages[-2][1]
+
 
     if len(state.messages) == state.offset + 2:
         # 第一轮对话需要加入提示Prompt
@@ -168,6 +168,8 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
         new_state = conv_templates[template_name].copy()
         new_state.conv_id = uuid.uuid4().hex
 
+        query = state.messages[-2][1]
+
         # prompt 中添加上下文提示
         if db_selector:
             new_state.append_message(new_state.roles[0], gen_sqlgen_conversation(dbname) + query)
@@ -178,13 +180,13 @@ def http_bot(state, db_selector, temperature, max_new_tokens, request: gr.Reques
             new_state.append_message(new_state.roles[1], None)
             state = new_state
 
-        try:
-            if not db_selector:
-                sim_q = get_simlar(query)
-                print("********vector similar info*************: ", sim_q)
-                state.append_message(new_state.roles[0], sim_q + query)
-        except Exception as e:
-            print(e)
+        # try:
+        #     if not db_selector:
+        #         sim_q = get_simlar(query)
+        #         print("********vector similar info*************: ", sim_q)
+        #         state.append_message(new_state.roles[0], sim_q + query)
+        # except Exception as e:
+        #     print(e)
 
     prompt = state.get_prompt()
 
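A note on load_knownledge_from_doc() as it stands after patch 1: the loop rebinds knownledge on every file, so only the last document in DATASETS_DIR ever reaches the text splitter, and CharacterTextSplitter is called with chunk_owerlap, which is not a parameter LangChain accepts (the documented name is chunk_overlap, so the call raises a TypeError). The sketch below is one possible cleanup, not the committed code: it keeps the repo's knownledge_* naming and the DATASETS_DIR / VECTORE_PATH / LLM_MODEL_CONFIG settings shown in the diff, and it assumes the standard langchain imports (HuggingFaceEmbeddings, CharacterTextSplitter, Chroma) that the existing extract_tovec.py already relies on.

    import os

    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import CharacterTextSplitter
    from langchain.vectorstores import Chroma

    from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG, VECTORE_PATH


    def load_knownledge_from_doc():
        """Build one Chroma index over every readable file under DATASETS_DIR (sketch)."""
        embeddings = HuggingFaceEmbeddings(model_name=LLM_MODEL_CONFIG["sentence-transforms"])

        docs = []
        for file in os.listdir(DATASETS_DIR):
            filename = os.path.join(DATASETS_DIR, file)
            # Check the full path, not the bare name, when skipping sub-directories.
            if os.path.isdir(filename):
                continue
            with open(filename, "r", encoding="utf-8") as f:
                # Collect every document instead of overwriting a single variable.
                docs.append(f.read())

        # chunk_overlap is the documented LangChain parameter; chunk_owerlap raises a TypeError.
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_text("\n".join(docs))

        return Chroma.from_texts(
            texts,
            embeddings,
            metadatas=[{"source": str(i)} for i in range(len(texts))],
            persist_directory=os.path.join(VECTORE_PATH, ".vectore"),
        )

With an aggregated index along these lines, get_similar() in pilot/app.py and get_simlar() in the webserver could query a single persisted store instead of re-embedding plan.md on every request, which appears to be what patch 1 attempted before patch 2 reverted both call sites to knownledge_tovec_st.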