From d3567fb984672c68dfd2962f75afd1ea148289ff Mon Sep 17 00:00:00 2001
From: aries-ckt <916701291@qq.com>
Date: Thu, 25 May 2023 13:01:30 +0800
Subject: [PATCH] update:default knowledge init

---
 pilot/server/vectordb_qa.py                   | 25 +++++++++++
 pilot/server/webserver.py                     | 44 +++++++------------
 pilot/source_embedding/knowledge_embedding.py |  2 +-
 3 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py
index 6bf0b4688..a3fe9be4a 100644
--- a/pilot/server/vectordb_qa.py
+++ b/pilot/server/vectordb_qa.py
@@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
 
 from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
 from pilot.conversation import conv_qa_prompt_template
+from pilot.logs import logger
 from pilot.model.vicuna_llm import VicunaLLM
 from pilot.vector_store.file_loader import KnownLedge2Vector
 
@@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
         context = [d.page_content for d in docs]
         result = prompt.format(context="\n".join(context), question=query)
         return result
+
+    @staticmethod
+    def build_knowledge_prompt(query, docs, state):
+        prompt_template = PromptTemplate(
+            template=conv_qa_prompt_template, input_variables=["context", "question"]
+        )
+        context = [d.page_content for d in docs]
+        result = prompt_template.format(context="\n".join(context), question=query)
+        state.messages[-2][1] = result
+        prompt = state.get_prompt()
+
+        if len(prompt) > 4000:
+            logger.info("prompt length greater than 4000, rebuild")
+            context = context[:2000]
+            prompt_template = PromptTemplate(
+                template=conv_qa_prompt_template,
+                input_variables=["context", "question"],
+            )
+            result = prompt_template.format(context="\n".join(context), question=query)
+            state.messages[-2][1] = result
+            prompt = state.get_prompt()
+            print("new prompt length:" + str(len(prompt)))
+
+        return prompt
\ No newline at end of file
diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py
index 04737f6f2..3a2e9f891 100644
--- a/pilot/server/webserver.py
+++ b/pilot/server/webserver.py
@@ -13,7 +13,8 @@ from urllib.parse import urljoin
 
 import gradio as gr
 import requests
-from langchain import PromptTemplate
+
+from pilot.server.vectordb_qa import KnownLedgeBaseQA
 
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
@@ -40,16 +41,13 @@ from pilot.conversation import (
 )
 from pilot.plugins import scan_plugins
 from pilot.prompts.auto_mode_prompt import AutoModePrompt
-from pilot.prompts.generator import PromptGenerator
 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
-from pilot.server.vectordb_qa import KnownLedgeBaseQA
 from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding
 from pilot.utils import build_logger, server_error_msg
 from pilot.vector_store.extract_tovec import (
     get_vector_storelist,
     knownledge_tovec_st,
-    load_knownledge_from_doc,
 )
 
 logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -263,10 +261,19 @@ def http_bot(
         prompt = state.get_prompt()
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
     if mode == conversation_types["default_knownledge"] and not db_selector:
+        vector_store_config = {
+            "vector_store_name": "default",
+            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
+        }
+        knowledge_embedding_client = KnowledgeEmbedding(
+            file_path="",
+            model_name=LLM_MODEL_CONFIG["text2vec"],
+            local_persist=False,
+            vector_store_config=vector_store_config,
+        )
         query = state.messages[-2][1]
-        knqa = KnownLedgeBaseQA()
-        state.messages[-2][1] = knqa.get_similar_answer(query)
-        prompt = state.get_prompt()
+        docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
@@ -285,26 +292,7 @@ def http_bot(
         )
         query = state.messages[-2][1]
         docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
-        context = [d.page_content for d in docs]
-        prompt_template = PromptTemplate(
-            template=conv_qa_prompt_template, input_variables=["context", "question"]
-        )
-        result = prompt_template.format(context="\n".join(context), question=query)
-        state.messages[-2][1] = result
-        prompt = state.get_prompt()
-        print("prompt length:" + str(len(prompt)))
-
-        if len(prompt) > 4000:
-            logger.info("prompt length greater than 4000, rebuild")
-            context = context[:2000]
-            prompt_template = PromptTemplate(
-                template=conv_qa_prompt_template,
-                input_variables=["context", "question"],
-            )
-            result = prompt_template.format(context="\n".join(context), question=query)
-            state.messages[-2][1] = result
-            prompt = state.get_prompt()
-            print("new prompt length:" + str(len(prompt)))
+        prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
         state.messages[-2][1] = query
         skip_echo_len = len(prompt.replace("</s>", " ")) + 1
 
@@ -697,7 +685,7 @@ if __name__ == "__main__":
     # Initialize configuration
     cfg = Config()
-    dbs = get_database_list()
+    # dbs = get_database_list()
 
     cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
 
     # Load plugin executable commands
diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py
index 33f35f826..316667dee 100644
--- a/pilot/source_embedding/knowledge_embedding.py
+++ b/pilot/source_embedding/knowledge_embedding.py
@@ -2,7 +2,7 @@ import os
 
 import markdown
 from bs4 import BeautifulSoup
-from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
+from langchain.document_loaders import PyPDFLoader, TextLoader
 from langchain.embeddings import HuggingFaceEmbeddings
 
 from pilot.configs.config import Config
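
Note on the refactored flow: both knowledge-QA branches of http_bot now delegate
prompt construction to the new KnownLedgeBaseQA.build_knowledge_prompt(query,
docs, state) helper. Below is a minimal sketch of how the "default" branch is
wired together, assuming KNOWLEDGE_UPLOAD_ROOT_PATH and LLM_MODEL_CONFIG come
from pilot.configs.model_config as webserver.py uses them; the wrapper function
name is hypothetical, for illustration only:

    from pilot.configs.model_config import (
        KNOWLEDGE_UPLOAD_ROOT_PATH,
        LLM_MODEL_CONFIG,
        VECTOR_SEARCH_TOP_K,
    )
    from pilot.server.vectordb_qa import KnownLedgeBaseQA
    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

    def default_knowledge_prompt(query, state):
        """Retrieve top-k chunks from the default store and render the QA prompt."""
        client = KnowledgeEmbedding(
            file_path="",  # nothing to ingest; only querying the existing store
            model_name=LLM_MODEL_CONFIG["text2vec"],
            local_persist=False,
            vector_store_config={
                "vector_store_name": "default",
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
            },
        )
        docs = client.similar_search(query, VECTOR_SEARCH_TOP_K)
        # Writes the rendered conv_qa_prompt_template into state.messages[-2][1]
        # and rebuilds once if the assembled prompt exceeds 4000 characters.
        return KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)

One caveat carried over from the old inline code: in the rebuild branch,
context = context[:2000] slices the list of retrieved chunks to 2000 elements
rather than truncating the joined text to 2000 characters, so the second pass
may not actually shorten an oversized prompt.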