diff --git a/pilot/conversation.py b/pilot/conversation.py index 253dede3a..073f25f24 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -231,8 +231,8 @@ auto_dbgpt_without_shot = Conversation( sep2="", ) -conv_qa_prompt_template = """ 基于以下已知的信息, 专业、详细的回答用户的问题, - 如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题", 但是你可以给出一些与问题相关答案的建议。 +conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回答用户的问题, + 如果无法从提供的恶内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index dce0157a0..95fdbaf68 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -11,6 +11,9 @@ import gradio as gr import datetime import requests from urllib.parse import urljoin + +from langchain import PromptTemplate + from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG from pilot.server.vectordb_qa import KnownLedgeBaseQA from pilot.connections.mysql import MySQLOperator @@ -32,7 +35,7 @@ from pilot.conversation import ( conv_templates, conversation_types, conversation_sql_mode, - SeparatorStyle + SeparatorStyle, conv_qa_prompt_template ) from pilot.utils import ( @@ -57,6 +60,8 @@ models = [] dbs = [] vs_list = ["新建知识库"] + get_vector_storelist() autogpt = False +vector_store_client = None +vector_store_name = {"vs_name": ""} priority = { "vicuna-13b": "aaa" @@ -217,16 +222,28 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re state.messages[0][1] = "" state.messages[-2][1] = follow_up_prompt - if mode == conversation_types["default_knownledge"] and not db_selector: query = state.messages[-2][1] knqa = KnownLedgeBaseQA() state.messages[-2][1] = knqa.get_similar_answer(query) - - prompt = state.get_prompt() - - skip_echo_len = len(prompt.replace("", " ")) + 1 + if mode == conversation_types["custome"] and not db_selector: + persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb") + print("向量数据库持久化地址: ", persist_dir) + knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["sentence-transforms"], vector_store_config={"vector_store_name": vector_store_name["vs_name"], + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) + query = state.messages[-2][1] + docs = knowledge_embedding_client.similar_search(query, 1) + context = [d.page_content for d in docs] + prompt_template = PromptTemplate( + template=conv_qa_prompt_template, + input_variables=["context", "question"] + ) + result = prompt_template.format(context="\n".join(context), question=query) + state.messages[-2][1] = result + prompt = state.get_prompt() + state.messages[-2][1] = query + skip_echo_len = len(prompt.replace("", " ")) + 1 # Make requests payload = { @@ -437,8 +454,9 @@ def build_single_model_ui(): load_file_button = gr.Button("上传并加载到知识库") with gr.Tab("上传文件夹"): - folder_files = gr.File(label="添加文件", - file_count="directory", + folder_files = gr.File(label="添加文件夹", + accept_multiple_files=True, + file_count="directory", show_label=False) load_folder_button = gr.Button("上传并加载到知识库") @@ -483,15 +501,17 @@ def build_single_model_ui(): [state, mode, sql_mode, db_selector, temperature, max_output_tokens], [state, chatbot] + btn_list ) + vs_add.click(fn=save_vs_name, show_progress=True, + inputs=[vs_name], + outputs=[vs_name]) load_file_button.click(fn=knowledge_embedding_store, show_progress=True, inputs=[vs_name, files], outputs=[vs_name]) - # load_folder_button.click(get_vector_store, - # show_progress=True, - # inputs=[vs_name, folder_files, 100 , chatbot, vs_add, - # vs_add], - # outputs=["db-out", folder_files, chatbot]) + load_folder_button.click(fn=knowledge_embedding_store, + show_progress=True, + inputs=[vs_name, folder_files], + outputs=[vs_name]) return state, chatbot, textbox, send_btn, button_row, parameter_row @@ -531,6 +551,10 @@ def build_webdemo(): return demo +def save_vs_name(vs_name): + vector_store_name["vs_name"] = vs_name + return vs_name + def knowledge_embedding_store(vs_id, files): # vs_path = os.path.join(VS_ROOT_PATH, vs_id) if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)): @@ -538,10 +562,15 @@ def knowledge_embedding_store(vs_id, files): for file in files: filename = os.path.split(file.name)[-1] shutil.move(file.name, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename)) + knowledge_embedding_client = KnowledgeEmbedding( + file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), + model_name=LLM_MODEL_CONFIG["sentence-transforms"], + vector_store_config={ + "vector_store_name": vector_store_name["vs_name"], + "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) + knowledge_embedding_client.knowledge_embedding() + - knowledge_embedding = KnowledgeEmbedding.knowledge_embedding(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), LLM_MODEL_CONFIG["sentence-transforms"], {"vector_store_name": vs_id, - "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) - knowledge_embedding.source_embedding() logger.info("knowledge embedding success") return os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, vs_id + ".vectordb") diff --git a/pilot/source_embedding/csv_embedding.py b/pilot/source_embedding/csv_embedding.py index db73ae7e5..2f3b7ed06 100644 --- a/pilot/source_embedding/csv_embedding.py +++ b/pilot/source_embedding/csv_embedding.py @@ -10,6 +10,7 @@ class CSVEmbedding(SourceEmbedding): def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None): """Initialize with csv path.""" + super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index adc8c430f..a9e4d4e4e 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -4,17 +4,31 @@ from pilot.source_embedding.pdf_embedding import PDFEmbedding class KnowledgeEmbedding: - @staticmethod - def knowledge_embedding(file_path:str, model_name, vector_store_config): - if file_path.endswith(".pdf"): - embedding = PDFEmbedding(file_path=file_path, model_name=model_name, - vector_store_config=vector_store_config) - elif file_path.endswith(".md"): - embedding = MarkdownEmbedding(file_path=file_path, model_name=model_name, - vector_store_config=vector_store_config) + def __init__(self, file_path, model_name, vector_store_config): + """Initialize with Loader url, model_name, vector_store_config""" + self.file_path = file_path + self.model_name = model_name + self.vector_store_config = vector_store_config + self.vector_store_type = "default" + self.knowledge_embedding_client = self.init_knowledge_embedding() - elif file_path.endswith(".csv"): - embedding = CSVEmbedding(file_path=file_path, model_name=model_name, - vector_store_config=vector_store_config) + def knowledge_embedding(self): + self.knowledge_embedding_client.source_embedding() - return embedding \ No newline at end of file + def init_knowledge_embedding(self): + if self.file_path.endswith(".pdf"): + embedding = PDFEmbedding(file_path=self.file_path, model_name=self.model_name, + vector_store_config=self.vector_store_config) + elif self.file_path.endswith(".md"): + embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config) + + elif self.file_path.endswith(".csv"): + embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name, + vector_store_config=self.vector_store_config) + elif self.vector_store_type == "default": + embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config) + + return embedding + + def similar_search(self, text, topk): + return self.knowledge_embedding_client.similar_search(text, topk) \ No newline at end of file diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index 66f8c5aa5..622011006 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -15,6 +15,7 @@ class MarkdownEmbedding(SourceEmbedding): def __init__(self, file_path, model_name, vector_store_config): """Initialize with markdown path.""" + super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index 71a310bc3..617190fe5 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -13,9 +13,12 @@ class PDFEmbedding(SourceEmbedding): def __init__(self, file_path, model_name, vector_store_config): """Initialize with pdf path.""" + super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config + # SourceEmbedding(file_path =file_path, ); + SourceEmbedding(file_path, model_name, vector_store_config) @register def read(self): diff --git a/pilot/source_embedding/search_milvus.py b/pilot/source_embedding/search_milvus.py index 18f93d1d3..ec0aa6813 100644 --- a/pilot/source_embedding/search_milvus.py +++ b/pilot/source_embedding/search_milvus.py @@ -50,7 +50,7 @@ # # # text_embeddings = Text2Vectors() # mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"}) -# +# # mivuls.insert(["textc","tezt2"]) # print("success") # ct diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index b76f5ad46..656d24eaf 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -22,12 +22,16 @@ class SourceEmbedding(ABC): Implementations should implement the method """ - def __init__(self, yuque_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None): - """Initialize with YuqueLoader url, model_name, vector_store_config""" - self.yuque_path = yuque_path + def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None): + """Initialize with Loader url, model_name, vector_store_config""" + self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config self.embedding_args = embedding_args + self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) + persist_dir = os.path.join(self.vector_store_config["vector_store_path"], + self.vector_store_config["vector_store_name"] + ".vectordb") + self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) @abstractmethod @register @@ -50,18 +54,16 @@ class SourceEmbedding(ABC): @register def index_to_store(self, docs): """index to vector store""" - embeddings = HuggingFaceEmbeddings(model_name=self.model_name) - persist_dir = os.path.join(self.vector_store_config["vector_store_path"], self.vector_store_config["vector_store_name"] + ".vectordb") - self.vector_store = Chroma.from_documents(docs, embeddings, persist_directory=persist_dir) + self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir) self.vector_store.persist() @register def similar_search(self, doc, topk): """vector store similarity_search""" - return self.vector_store.similarity_search(doc, topk) + return self.vector_store_client.similarity_search(doc, topk) def source_embedding(self): if 'read' in registered_methods: