From be1a792d3c607cc170d73033f63bbc6885d1f9ae Mon Sep 17 00:00:00 2001
From: aries-ckt <916701291@qq.com>
Date: Mon, 5 Jun 2023 16:26:19 +0800
Subject: [PATCH] feature:knowledge embedding support file path auto adapt

---
 pilot/configs/config.py                       |  2 +
 pilot/configs/model_config.py                 |  1 -
 pilot/data/__init__.py                        |  0
 pilot/scene/chat_knowledge/custom/chat.py     |  2 -
 pilot/scene/chat_knowledge/custom/prompt.py   | 10 ++-
 pilot/scene/chat_knowledge/default/chat.py    |  2 -
 pilot/scene/chat_knowledge/default/prompt.py  | 10 ++-
 pilot/scene/chat_knowledge/url/chat.py        |  6 +-
 pilot/scene/chat_knowledge/url/prompt.py      | 17 ++---
 pilot/source_embedding/EncodeTextLoader.py    | 26 ++++++++
 pilot/source_embedding/knowledge_embedding.py | 62 ++++++++-----------
 pilot/source_embedding/markdown_embedding.py  |  4 +-
 pilot/source_embedding/pdf_embedding.py       |  4 +-
 pilot/source_embedding/word_embedding.py      | 38 ++++++++++++
 pilot/summary/db_summary_client.py            |  8 ---
 tools/knowlege_init.py                        | 30 ++++-----
 16 files changed, 140 insertions(+), 82 deletions(-)
 delete mode 100644 pilot/data/__init__.py
 create mode 100644 pilot/source_embedding/EncodeTextLoader.py
 create mode 100644 pilot/source_embedding/word_embedding.py

diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 3762b43c1..6f6271477 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -148,6 +148,8 @@ class Config(metaclass=Singleton):
 
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
+        self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))
 
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 7bbbac361..4b8b85a62 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -34,7 +34,6 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
-    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
diff --git a/pilot/data/__init__.py b/pilot/data/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py
index 7600bab79..8fc0f3d82 100644
--- a/pilot/scene/chat_knowledge/custom/chat.py
+++ b/pilot/scene/chat_knowledge/custom/chat.py
@@ -46,9 +46,7 @@ class ChatNewKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )
diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py
index 110250221..a76bb70ba 100644
--- a/pilot/scene/chat_knowledge/custom/prompt.py
+++ b/pilot/scene/chat_knowledge/custom/prompt.py
@@ -14,13 +14,21 @@ CFG = Config()
 
 PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
 
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: {question} """ +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 1a482b154..1087ee2c0 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -42,9 +42,7 @@ class ChatDefaultKnowledge(BaseChat): "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, vector_store_config=vector_store_config, ) diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py index 0526be69b..5dc512898 100644 --- a/pilot/scene/chat_knowledge/default/prompt.py +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -15,13 +15,21 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli The assistant gives helpful, detailed, professional and polite answers to the user's questions. """ -_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, +_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: {question} """ +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index cc8d89d4a..2634dc80d 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -40,15 +40,13 @@ class ChatUrlKnowledge(BaseChat): self.url = url vector_store_config = { "vector_store_name": url, - "text_field": "content", "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - file_path=url, - file_type="url", model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, vector_store_config=vector_store_config, + file_type="url", + file_path=url, ) # url soruce in vector diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py index 38d5dfe35..e887cc53f 100644 --- a/pilot/scene/chat_knowledge/url/prompt.py +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -14,20 +14,21 @@ CFG = Config() PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. The assistant gives helpful, detailed, professional and polite answers to the user's questions. 
""" - -# _DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.。 -# known information: -# {context} -# question: -# {question} -# """ -_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, +_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: {question} """ +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/source_embedding/EncodeTextLoader.py b/pilot/source_embedding/EncodeTextLoader.py new file mode 100644 index 000000000..7f3ba7e5a --- /dev/null +++ b/pilot/source_embedding/EncodeTextLoader.py @@ -0,0 +1,26 @@ +from typing import List, Optional +import chardet + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader + + +class EncodeTextLoader(BaseLoader): + """Load text files.""" + + def __init__(self, file_path: str, encoding: Optional[str] = None): + """Initialize with file path.""" + self.file_path = file_path + self.encoding = encoding + + def load(self) -> List[Document]: + """Load from file path.""" + with open(self.file_path, 'rb') as f: + raw_text = f.read() + result = chardet.detect(raw_text) + if result['encoding'] is None: + text = raw_text.decode('utf-8') + else: + text = raw_text.decode(result['encoding']) + metadata = {"source": self.file_path} + return [Document(page_content=text, metadata=metadata)] diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index f58742ee9..97d9f590b 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,4 +1,5 @@ import os +from typing import Optional import markdown from bs4 import BeautifulSoup @@ -12,19 +13,28 @@ from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding from pilot.source_embedding.url_embedding import URLEmbedding +from pilot.source_embedding.word_embedding import WordEmbedding from pilot.vector_store.connector import VectorStoreConnector CFG = Config() +KnowledgeEmbeddingType = { + ".txt": (MarkdownEmbedding, {}), + ".md": (MarkdownEmbedding,{}), + ".pdf": (PDFEmbedding, {}), + ".doc": (WordEmbedding, {}), + ".docx": (WordEmbedding, {}), + ".csv": (CSVEmbedding, {}), +} class KnowledgeEmbedding: def __init__( self, - file_path, model_name, vector_store_config, - local_persist=True, - file_type="default", + file_type: Optional[str] = "default", + file_path: Optional[str] = None, + ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path @@ -33,11 +43,9 @@ class KnowledgeEmbedding: self.file_type = file_type self.embeddings = 
HuggingFaceEmbeddings(model_name=self.model_name) self.vector_store_config["embeddings"] = self.embeddings - self.local_persist = local_persist - if not self.local_persist: - self.knowledge_embedding_client = self.init_knowledge_embedding() def knowledge_embedding(self): + self.knowledge_embedding_client = self.init_knowledge_embedding() self.knowledge_embedding_client.source_embedding() def knowledge_embedding_batch(self): @@ -50,40 +58,24 @@ class KnowledgeEmbedding: model_name=self.model_name, vector_store_config=self.vector_store_config, ) - elif self.file_path.endswith(".pdf"): - embedding = PDFEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - elif self.file_path.endswith(".md"): - embedding = MarkdownEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - - elif self.file_path.endswith(".csv"): - embedding = CSVEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - - elif self.file_type == "default": - embedding = MarkdownEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - + return embedding + extension = "." + self.file_path.rsplit(".", 1)[-1] + if extension in KnowledgeEmbeddingType: + knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] + embedding = knowledge_class(self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config, **knowledge_args) + return embedding + raise ValueError(f"Unsupported knowledge file type '{extension}'") return embedding def similar_search(self, text, topk): - return self.knowledge_embedding_client.similar_search(text, topk) + vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config) + return vector_client.similar_search(text, topk) def vector_exist(self): - return self.knowledge_embedding_client.vector_name_exist() + vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) + return vector_client.vector_name_exists() def knowledge_persist_initialization(self, append_mode): documents = self._load_knownlege(self.file_path) diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index 3db6cdbf5..da974c366 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -10,6 +10,7 @@ from langchain.schema import Document from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.source_embedding import SourceEmbedding, register +from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter @@ -22,11 +23,12 @@ class MarkdownEmbedding(SourceEmbedding): self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config + # self.encoding = encoding @register def read(self): """Load from markdown path.""" - loader = TextLoader(self.file_path) + loader = EncodeTextLoader(self.file_path) text_splitter = CHNDocumentSplitter( pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE ) diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index de1767c51..b3f64b788 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -13,13 +13,13 @@ from 
pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter class PDFEmbedding(SourceEmbedding): """pdf embedding for read pdf document.""" - def __init__(self, file_path, model_name, vector_store_config): + def __init__(self, file_path, model_name, vector_store_config, encoding): """Initialize with pdf path.""" super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config - + self.encoding = encoding @register def read(self): """Load from pdf path.""" diff --git a/pilot/source_embedding/word_embedding.py b/pilot/source_embedding/word_embedding.py new file mode 100644 index 000000000..5dd2f0199 --- /dev/null +++ b/pilot/source_embedding/word_embedding.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from typing import List + +from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader +from langchain.schema import Document + +from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.source_embedding import SourceEmbedding, register +from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter + + +class WordEmbedding(SourceEmbedding): + """word embedding for read word document.""" + + def __init__(self, file_path, model_name, vector_store_config): + """Initialize with word path.""" + super().__init__(file_path, model_name, vector_store_config) + self.file_path = file_path + self.model_name = model_name + self.vector_store_config = vector_store_config + + @register + def read(self): + """Load from word path.""" + loader = UnstructuredWordDocumentLoader(self.file_path) + textsplitter = CHNDocumentSplitter( + pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + ) + return loader.load_and_split(textsplitter) + + @register + def data_process(self, documents: List[Document]): + i = 0 + for d in documents: + documents[i].page_content = d.page_content.replace("\n", "") + i += 1 + return documents diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 91805ddd4..c5bfcc718 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -74,17 +74,11 @@ class DBSummaryClient: @staticmethod def get_similar_tables(dbname, query, topk): """get user query related tables info""" - embeddings = HuggingFaceEmbeddings( - model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] - ) vector_store_config = { "vector_store_name": dbname + "_profile", - "embeddings": embeddings, } knowledge_embedding_client = KnowledgeEmbedding( - file_path="", model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - local_persist=False, vector_store_config=vector_store_config, ) if CFG.SUMMARY_CONFIG == "FAST": @@ -105,12 +99,10 @@ class DBSummaryClient: for table in related_tables: vector_store_config = { "vector_store_name": table + "_ts", - "embeddings": embeddings, } knowledge_embedding_client = KnowledgeEmbedding( file_path="", model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - local_persist=False, vector_store_config=vector_store_config, ) table_summery = knowledge_embedding_client.similar_search(query, 1) diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index 03c9633d3..018cd3000 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -19,36 +19,32 @@ CFG = Config() class LocalKnowledgeInit: embeddings: object = None - model_name = LLM_MODEL_CONFIG["text2vec"] top_k: int = VECTOR_SEARCH_TOP_K def __init__(self, vector_store_config) -> None: 
         self.vector_store_config = vector_store_config
+        self.model_name = LLM_MODEL_CONFIG["text2vec"]
 
     def knowledge_persist(self, file_path, append_mode):
         """knowledge persist"""
-        kv = KnowledgeEmbedding(
-            file_path=file_path,
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            vector_store_config=self.vector_store_config,
-        )
-        vector_store = kv.knowledge_persist_initialization(append_mode)
-        return vector_store
+        for root, _, files in os.walk(file_path, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                # one embedding client per file; the loader is picked by extension
+                ke = KnowledgeEmbedding(
+                    file_path=filename,
+                    model_name=self.model_name,
+                    vector_store_config=self.vector_store_config,
+                )
+                client = ke.init_knowledge_embedding()
+                client.source_embedding()
 
-    def query(self, q):
-        """Query similar doc from Vector"""
-        vector_store = self.init_vector_store()
-        docs = vector_store.similarity_search_with_score(q, k=self.top_k)
-        for doc in docs:
-            dc, s = doc
-            yield s, dc
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--vector_name", type=str, default="default")
     parser.add_argument("--append", type=bool, default=False)
-    parser.add_argument("--store_type", type=str, default="Chroma")
     args = parser.parse_args()
     vector_name = args.vector_name
    append_mode = args.append
@@ -56,5 +52,5 @@ if __name__ == "__main__":
     vector_store_config = {"vector_store_name": vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+    kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
     print("your knowledge embedding success...")
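
---
Usage sketch (not part of the patch): with this change applied, single-file
ingestion plus retrieval is intended to look roughly like the following. The
vector_store_config keys mirror the ones used in the chat.py hunks above; the
input path is hypothetical, and import locations follow the modules this patch
touches.

    from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

    # The loader class is picked from KnowledgeEmbeddingType by file extension
    # (.txt/.md/.pdf/.doc/.docx/.csv); an unknown suffix raises ValueError.
    client = KnowledgeEmbedding(
        model_name=LLM_MODEL_CONFIG["text2vec"],
        vector_store_config={
            "vector_store_name": "my_document",
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        },
        file_path="/path/to/my_document.md",  # hypothetical input file
    )
    client.knowledge_embedding()  # read -> split -> embed -> persist
    docs = client.similar_search("What does this document cover?", topk=5)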
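
Because loader dispatch is now table-driven, supporting another suffix is a
one-line registration rather than another elif branch. A hypothetical .rst
entry that reuses the text pipeline:

    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbeddingType
    from pilot.source_embedding.markdown_embedding import MarkdownEmbedding

    # The second tuple element carries extra keyword arguments that
    # init_knowledge_embedding() forwards to the loader's constructor.
    KnowledgeEmbeddingType[".rst"] = (MarkdownEmbedding, {})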
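
EncodeTextLoader also works standalone: it reads the file as raw bytes and
decodes with an explicit encoding when one is passed, falling back to chardet
detection otherwise, which is what makes .txt/.md ingestion tolerant of
non-UTF-8 corpora. For example (file name hypothetical):

    from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader

    docs = EncodeTextLoader("/path/to/gbk_encoded_notes.txt").load()
    print(docs[0].metadata["source"], len(docs[0].page_content))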