From 2656a8030ecf24c20ef582a302bb865a791a801a Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Thu, 18 May 2023 20:03:24 +0800 Subject: [PATCH] feature:knowledge embedding update --- .gitignore | 3 +- pilot/conversation.py | 7 ++ pilot/server/webserver.py | 13 +-- pilot/source_embedding/knowledge_embedding.py | 82 ++++++++++++++++++- pilot/source_embedding/markdown_embedding.py | 27 +++++- pilot/source_embedding/pdf_embedding.py | 4 +- pilot/source_embedding/source_embedding.py | 12 +++ pilot/source_embedding/url_embedding.py | 6 +- pilot/vector_store/file_loader.py | 6 +- 9 files changed, 143 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index cb21ee557..5043f7db0 100644 --- a/.gitignore +++ b/.gitignore @@ -136,4 +136,5 @@ dmypy.json .DS_Store logs nltk_data -.vectordb \ No newline at end of file +.vectordb +pilot/data/ \ No newline at end of file diff --git a/pilot/conversation.py b/pilot/conversation.py index 7054fb453..7f526fb89 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -247,6 +247,13 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回 {question} """ +# conv_qa_prompt_template = """ Please provide the known information so that I can professionally and briefly answer the user's question. If the answer cannot be obtained from the provided content, +# please say: "The information provided in the knowledge base is insufficient to answer this question." Fabrication is prohibited.。 +# known information: +# {context} +# question: +# {question} +# """ default_conversation = conv_one_shot conversation_sql_mode ={ diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 289ce7f32..e6ba19160 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -19,7 +19,7 @@ from langchain import PromptTemplate ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) -from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG +from pilot.configs.model_config import DB_SETTINGS, KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, TOP_RETURN_SIZE from pilot.server.vectordb_qa import KnownLedgeBaseQA from pilot.connections.mysql import MySQLOperator from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding @@ -256,11 +256,13 @@ def http_bot(state, mode, sql_mode, db_selector, temperature, max_new_tokens, re if mode == conversation_types["custome"] and not db_selector: persist_dir = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vector_store_name["vs_name"] + ".vectordb") - print("向量数据库持久化地址: ", persist_dir) - knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config={"vector_store_name": vector_store_name["vs_name"], + print("vector store path: ", persist_dir) + knowledge_embedding_client = KnowledgeEmbedding(file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, + vector_store_config={"vector_store_name": vector_store_name["vs_name"], "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) query = state.messages[-2][1] - docs = knowledge_embedding_client.similar_search(query, 10) + docs = knowledge_embedding_client.similar_search(query, TOP_RETURN_SIZE) context = [d.page_content for d in docs] prompt_template = PromptTemplate( template=conv_qa_prompt_template, @@ -600,6 +602,7 @@ def knowledge_embedding_store(vs_id, files): knowledge_embedding_client = KnowledgeEmbedding( file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), model_name=LLM_MODEL_CONFIG["text2vec"], + local_persist=False, vector_store_config={ "vector_store_name": vector_store_name["vs_name"], "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH}) @@ -624,7 +627,7 @@ if __name__ == "__main__": # 配置初始化 cfg = Config() - dbs = get_database_list() + # dbs = get_database_list() cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index a9e4d4e4e..594723b6e 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,20 +1,35 @@ +import os + +from bs4 import BeautifulSoup +from langchain.document_loaders import PyPDFLoader, TextLoader, markdown +from langchain.embeddings import HuggingFaceEmbeddings +from langchain.vectorstores import Chroma +from pilot.configs.model_config import DATASETS_DIR +from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding +import markdown class KnowledgeEmbedding: - def __init__(self, file_path, model_name, vector_store_config): + def __init__(self, file_path, model_name, vector_store_config, local_persist=True): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config self.vector_store_type = "default" - self.knowledge_embedding_client = self.init_knowledge_embedding() + self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) + self.local_persist = local_persist + if not self.local_persist: + self.knowledge_embedding_client = self.init_knowledge_embedding() def knowledge_embedding(self): self.knowledge_embedding_client.source_embedding() + def knowledge_embedding_batch(self): + self.knowledge_embedding_client.batch_embedding() + def init_knowledge_embedding(self): if self.file_path.endswith(".pdf"): embedding = PDFEmbedding(file_path=self.file_path, model_name=self.model_name, @@ -31,4 +46,65 @@ class KnowledgeEmbedding: return embedding def similar_search(self, text, topk): - return self.knowledge_embedding_client.similar_search(text, topk) \ No newline at end of file + return self.knowledge_embedding_client.similar_search(text, topk) + + def knowledge_persist_initialization(self, append_mode): + vector_name = self.vector_store_config["vector_store_name"] + persist_dir = os.path.join(self.vector_store_config["vector_store_path"], vector_name + ".vectordb") + print("vector db path: ", persist_dir) + if os.path.exists(persist_dir): + if append_mode: + print("append knowledge return vector store") + new_documents = self._load_knownlege(self.file_path) + vector_store = Chroma.from_documents(documents=new_documents, + embedding=self.embeddings, + persist_directory=persist_dir) + else: + print("directly return vector store") + vector_store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings) + else: + print(vector_name + "is new vector store, knowledge begin load...") + documents = self._load_knownlege(self.file_path) + vector_store = Chroma.from_documents(documents=documents, + embedding=self.embeddings, + persist_directory=persist_dir) + vector_store.persist() + return vector_store + + def _load_knownlege(self, path): + docments = [] + for root, _, files in os.walk(path, topdown=False): + for file in files: + filename = os.path.join(root, file) + docs = self._load_file(filename) + new_docs = [] + for doc in docs: + doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")} + print("doc is embedding...", doc.metadata) + new_docs.append(doc) + docments += new_docs + return docments + + def _load_file(self, filename): + if filename.lower().endswith(".md"): + loader = TextLoader(filename) + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=100) + docs = loader.load_and_split(text_splitter) + i = 0 + for d in docs: + content = markdown.markdown(d.page_content) + soup = BeautifulSoup(content, 'html.parser') + for tag in soup(['!doctype', 'meta', 'i.fa']): + tag.extract() + docs[i].page_content = soup.get_text() + docs[i].page_content = docs[i].page_content.replace("\n", " ") + i += 1 + elif filename.lower().endswith(".pdf"): + loader = PyPDFLoader(filename) + textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=100) + docs = loader.load_and_split(textsplitter) + else: + loader = TextLoader(filename) + text_splitor = CHNDocumentSplitter(sentence_size=100) + docs = loader.load_and_split(text_splitor) + return docs \ No newline at end of file diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index 622011006..fee9504b6 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import os from typing import List from bs4 import BeautifulSoup @@ -8,6 +9,7 @@ from langchain.schema import Document import markdown from pilot.source_embedding import SourceEmbedding, register +from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter class MarkdownEmbedding(SourceEmbedding): @@ -24,7 +26,28 @@ class MarkdownEmbedding(SourceEmbedding): def read(self): """Load from markdown path.""" loader = TextLoader(self.file_path) - return loader.load() + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=100) + return loader.load_and_split(text_splitter) + + @register + def read_batch(self): + """Load from markdown path.""" + docments = [] + for root, _, files in os.walk(self.file_path, topdown=False): + for file in files: + filename = os.path.join(root, file) + loader = TextLoader(filename) + # text_splitor = CHNDocumentSplitter(chunk_size=1000, chunk_overlap=20, length_function=len) + # docs = loader.load_and_split() + docs = loader.load() + # 更新metadata数据 + new_docs = [] + for doc in docs: + doc.metadata = {"source": doc.metadata["source"].replace(self.file_path, "")} + print("doc is embedding ... ", doc.metadata) + new_docs.append(doc) + docments += new_docs + return docments @register def data_process(self, documents: List[Document]): @@ -35,7 +58,7 @@ class MarkdownEmbedding(SourceEmbedding): for tag in soup(['!doctype', 'meta', 'i.fa']): tag.extract() documents[i].page_content = soup.get_text() - documents[i].page_content = documents[i].page_content.replace(" ", "").replace("\n", " ") + documents[i].page_content = documents[i].page_content.replace("\n", " ") i += 1 return documents diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index 557637c5a..bd0ae3aba 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -6,7 +6,7 @@ from langchain.document_loaders import PyPDFLoader from langchain.schema import Document from pilot.source_embedding import SourceEmbedding, register -from pilot.source_embedding.chinese_text_splitter import ChineseTextSplitter +from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter class PDFEmbedding(SourceEmbedding): @@ -23,7 +23,7 @@ class PDFEmbedding(SourceEmbedding): def read(self): """Load from pdf path.""" loader = PyPDFLoader(self.file_path) - textsplitter = ChineseTextSplitter(pdf=True, sentence_size=100) + textsplitter = CHNDocumentSplitter(pdf=True, sentence_size=100) return loader.load_and_split(textsplitter) @register diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 656d24eaf..66bc97b6d 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -76,3 +76,15 @@ class SourceEmbedding(ABC): self.text_to_vector(text) if 'index_to_store' in registered_methods: self.index_to_store(text) + + def batch_embedding(self): + if 'read_batch' in registered_methods: + text = self.read_batch() + if 'data_process' in registered_methods: + text = self.data_process(text) + if 'text_split' in registered_methods: + self.text_split(text) + if 'text_to_vector' in registered_methods: + self.text_to_vector(text) + if 'index_to_store' in registered_methods: + self.index_to_store(text) diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py index 5fa29e0d2..68fbdd5e4 100644 --- a/pilot/source_embedding/url_embedding.py +++ b/pilot/source_embedding/url_embedding.py @@ -1,4 +1,7 @@ from typing import List + +from langchain.text_splitter import CharacterTextSplitter + from pilot.source_embedding import SourceEmbedding, register from bs4 import BeautifulSoup @@ -20,7 +23,8 @@ class URLEmbedding(SourceEmbedding): def read(self): """Load from url path.""" loader = WebBaseLoader(web_path=self.file_path) - return loader.load() + text_splitor = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20, length_function=len) + return loader.load_and_split(text_splitor) @register def data_process(self, documents: List[Document]): diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py index 8f668f60e..8703e2e4c 100644 --- a/pilot/vector_store/file_loader.py +++ b/pilot/vector_store/file_loader.py @@ -48,12 +48,12 @@ class KnownLedge2Vector: # vector_store.add_documents(documents=documents) else: documents = self.load_knownlege() - # reinit + # reinit vector_store = Chroma.from_documents(documents=documents, embedding=self.embeddings, persist_directory=persist_dir) vector_store.persist() - return vector_store + return vector_store def load_knownlege(self): docments = [] @@ -61,7 +61,7 @@ class KnownLedge2Vector: for file in files: filename = os.path.join(root, file) docs = self._load_file(filename) - # update metadata. + # update metadata. new_docs = [] for doc in docs: doc.metadata = {"source": doc.metadata["source"].replace(DATASETS_DIR, "")}