From be1a792d3c607cc170d73033f63bbc6885d1f9ae Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Mon, 5 Jun 2023 16:26:19 +0800 Subject: [PATCH 1/4] feature:knowledge embedding support file path auto adapt --- pilot/configs/config.py | 2 + pilot/configs/model_config.py | 1 - pilot/data/__init__.py | 0 pilot/scene/chat_knowledge/custom/chat.py | 2 - pilot/scene/chat_knowledge/custom/prompt.py | 10 ++- pilot/scene/chat_knowledge/default/chat.py | 2 - pilot/scene/chat_knowledge/default/prompt.py | 10 ++- pilot/scene/chat_knowledge/url/chat.py | 6 +- pilot/scene/chat_knowledge/url/prompt.py | 17 ++--- pilot/source_embedding/EncodeTextLoader.py | 26 ++++++++ pilot/source_embedding/knowledge_embedding.py | 62 ++++++++----------- pilot/source_embedding/markdown_embedding.py | 4 +- pilot/source_embedding/pdf_embedding.py | 4 +- pilot/source_embedding/word_embedding.py | 38 ++++++++++++ pilot/summary/db_summary_client.py | 8 --- tools/knowlege_init.py | 30 ++++----- 16 files changed, 140 insertions(+), 82 deletions(-) delete mode 100644 pilot/data/__init__.py create mode 100644 pilot/source_embedding/EncodeTextLoader.py create mode 100644 pilot/source_embedding/word_embedding.py diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 3762b43c1..6f6271477 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -148,6 +148,8 @@ class Config(metaclass=Singleton): ### EMBEDDING Configuration self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") + self.KNOWLEDGE_CHUNK_SIZE = os.getenv("KNOWLEDGE_CHUNK_SIZE", 100) + self.KNOWLEDGE_SEARCH_TOP_SIZE = os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10) ### SUMMARY_CONFIG Configuration self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR") diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 7bbbac361..4b8b85a62 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -34,7 +34,6 @@ LLM_MODEL_CONFIG = { "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"), "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"), "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"), - "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged" "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), "proxyllm": "proxyllm", diff --git a/pilot/data/__init__.py b/pilot/data/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 7600bab79..8fc0f3d82 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -46,9 +46,7 @@ class ChatNewKnowledge(BaseChat): "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, vector_store_config=vector_store_config, ) diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py index 110250221..a76bb70ba 100644 --- a/pilot/scene/chat_knowledge/custom/prompt.py +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -14,13 +14,21 @@ CFG = Config() PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers""" -_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, +_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 
如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: {question} """ +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 1a482b154..1087ee2c0 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -42,9 +42,7 @@ class ChatDefaultKnowledge(BaseChat): "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - file_path="", model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, vector_store_config=vector_store_config, ) diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py index 0526be69b..5dc512898 100644 --- a/pilot/scene/chat_knowledge/default/prompt.py +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -15,13 +15,21 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli The assistant gives helpful, detailed, professional and polite answers to the user's questions. """ -_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, +_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: {question} """ +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index cc8d89d4a..2634dc80d 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -40,15 +40,13 @@ class ChatUrlKnowledge(BaseChat): self.url = url vector_store_config = { "vector_store_name": url, - "text_field": "content", "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, } self.knowledge_embedding_client = KnowledgeEmbedding( - file_path=url, - file_type="url", model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, vector_store_config=vector_store_config, + file_type="url", + file_path=url, ) # url soruce in vector diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py index 38d5dfe35..e887cc53f 100644 --- a/pilot/scene/chat_knowledge/url/prompt.py +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -14,20 +14,21 @@ CFG = Config() PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. The assistant gives helpful, detailed, professional and polite answers to the user's questions. 
""" - -# _DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.。 -# known information: -# {context} -# question: -# {question} -# """ -_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题, +_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题, 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。 已知内容: {context} 问题: {question} """ +_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly. + known information: + {context} + question: + {question} +""" + +_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/source_embedding/EncodeTextLoader.py b/pilot/source_embedding/EncodeTextLoader.py new file mode 100644 index 000000000..7f3ba7e5a --- /dev/null +++ b/pilot/source_embedding/EncodeTextLoader.py @@ -0,0 +1,26 @@ +from typing import List, Optional +import chardet + +from langchain.docstore.document import Document +from langchain.document_loaders.base import BaseLoader + + +class EncodeTextLoader(BaseLoader): + """Load text files.""" + + def __init__(self, file_path: str, encoding: Optional[str] = None): + """Initialize with file path.""" + self.file_path = file_path + self.encoding = encoding + + def load(self) -> List[Document]: + """Load from file path.""" + with open(self.file_path, 'rb') as f: + raw_text = f.read() + result = chardet.detect(raw_text) + if result['encoding'] is None: + text = raw_text.decode('utf-8') + else: + text = raw_text.decode(result['encoding']) + metadata = {"source": self.file_path} + return [Document(page_content=text, metadata=metadata)] diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index f58742ee9..97d9f590b 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,4 +1,5 @@ import os +from typing import Optional import markdown from bs4 import BeautifulSoup @@ -12,19 +13,28 @@ from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding from pilot.source_embedding.url_embedding import URLEmbedding +from pilot.source_embedding.word_embedding import WordEmbedding from pilot.vector_store.connector import VectorStoreConnector CFG = Config() +KnowledgeEmbeddingType = { + ".txt": (MarkdownEmbedding, {}), + ".md": (MarkdownEmbedding,{}), + ".pdf": (PDFEmbedding, {}), + ".doc": (WordEmbedding, {}), + ".docx": (WordEmbedding, {}), + ".csv": (CSVEmbedding, {}), +} class KnowledgeEmbedding: def __init__( self, - file_path, model_name, vector_store_config, - local_persist=True, - file_type="default", + file_type: Optional[str] = "default", + file_path: Optional[str] = None, + ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path @@ -33,11 +43,9 @@ class KnowledgeEmbedding: self.file_type = file_type self.embeddings = 
HuggingFaceEmbeddings(model_name=self.model_name) self.vector_store_config["embeddings"] = self.embeddings - self.local_persist = local_persist - if not self.local_persist: - self.knowledge_embedding_client = self.init_knowledge_embedding() def knowledge_embedding(self): + self.knowledge_embedding_client = self.init_knowledge_embedding() self.knowledge_embedding_client.source_embedding() def knowledge_embedding_batch(self): @@ -50,40 +58,24 @@ class KnowledgeEmbedding: model_name=self.model_name, vector_store_config=self.vector_store_config, ) - elif self.file_path.endswith(".pdf"): - embedding = PDFEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - elif self.file_path.endswith(".md"): - embedding = MarkdownEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - - elif self.file_path.endswith(".csv"): - embedding = CSVEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - - elif self.file_type == "default": - embedding = MarkdownEmbedding( - file_path=self.file_path, - model_name=self.model_name, - vector_store_config=self.vector_store_config, - ) - + return embedding + extension = "." + self.file_path.rsplit(".", 1)[-1] + if extension in KnowledgeEmbeddingType: + knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] + embedding = knowledge_class(self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config, **knowledge_args) + return embedding + raise ValueError(f"Unsupported knowledge file type '{extension}'") return embedding def similar_search(self, text, topk): - return self.knowledge_embedding_client.similar_search(text, topk) + vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config) + return vector_client.similar_search(text, topk) def vector_exist(self): - return self.knowledge_embedding_client.vector_name_exist() + vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) + return vector_client.vector_name_exists() def knowledge_persist_initialization(self, append_mode): documents = self._load_knownlege(self.file_path) diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index 3db6cdbf5..da974c366 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -10,6 +10,7 @@ from langchain.schema import Document from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE from pilot.source_embedding import SourceEmbedding, register +from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter @@ -22,11 +23,12 @@ class MarkdownEmbedding(SourceEmbedding): self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config + # self.encoding = encoding @register def read(self): """Load from markdown path.""" - loader = TextLoader(self.file_path) + loader = EncodeTextLoader(self.file_path) text_splitter = CHNDocumentSplitter( pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE ) diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index de1767c51..b3f64b788 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -13,13 +13,13 @@ from 
pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter class PDFEmbedding(SourceEmbedding): """pdf embedding for read pdf document.""" - def __init__(self, file_path, model_name, vector_store_config): + def __init__(self, file_path, model_name, vector_store_config, encoding): """Initialize with pdf path.""" super().__init__(file_path, model_name, vector_store_config) self.file_path = file_path self.model_name = model_name self.vector_store_config = vector_store_config - + self.encoding = encoding @register def read(self): """Load from pdf path.""" diff --git a/pilot/source_embedding/word_embedding.py b/pilot/source_embedding/word_embedding.py new file mode 100644 index 000000000..5dd2f0199 --- /dev/null +++ b/pilot/source_embedding/word_embedding.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from typing import List + +from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader +from langchain.schema import Document + +from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.source_embedding import SourceEmbedding, register +from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter + + +class WordEmbedding(SourceEmbedding): + """word embedding for read word document.""" + + def __init__(self, file_path, model_name, vector_store_config): + """Initialize with word path.""" + super().__init__(file_path, model_name, vector_store_config) + self.file_path = file_path + self.model_name = model_name + self.vector_store_config = vector_store_config + + @register + def read(self): + """Load from word path.""" + loader = UnstructuredWordDocumentLoader(self.file_path) + textsplitter = CHNDocumentSplitter( + pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + ) + return loader.load_and_split(textsplitter) + + @register + def data_process(self, documents: List[Document]): + i = 0 + for d in documents: + documents[i].page_content = d.page_content.replace("\n", "") + i += 1 + return documents diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 91805ddd4..c5bfcc718 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -74,17 +74,11 @@ class DBSummaryClient: @staticmethod def get_similar_tables(dbname, query, topk): """get user query related tables info""" - embeddings = HuggingFaceEmbeddings( - model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL] - ) vector_store_config = { "vector_store_name": dbname + "_profile", - "embeddings": embeddings, } knowledge_embedding_client = KnowledgeEmbedding( - file_path="", model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - local_persist=False, vector_store_config=vector_store_config, ) if CFG.SUMMARY_CONFIG == "FAST": @@ -105,12 +99,10 @@ class DBSummaryClient: for table in related_tables: vector_store_config = { "vector_store_name": table + "_ts", - "embeddings": embeddings, } knowledge_embedding_client = KnowledgeEmbedding( file_path="", model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - local_persist=False, vector_store_config=vector_store_config, ) table_summery = knowledge_embedding_client.similar_search(query, 1) diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index 03c9633d3..018cd3000 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -19,36 +19,32 @@ CFG = Config() class LocalKnowledgeInit: embeddings: object = None - model_name = LLM_MODEL_CONFIG["text2vec"] top_k: int = VECTOR_SEARCH_TOP_K def __init__(self, vector_store_config) -> None: 
self.vector_store_config = vector_store_config + self.model_name = LLM_MODEL_CONFIG["text2vec"] def knowledge_persist(self, file_path, append_mode): """knowledge persist""" - kv = KnowledgeEmbedding( - file_path=file_path, - model_name=LLM_MODEL_CONFIG["text2vec"], - vector_store_config=self.vector_store_config, - ) - vector_store = kv.knowledge_persist_initialization(append_mode) - return vector_store + for root, _, files in os.walk(file_path, topdown=False): + for file in files: + filename = os.path.join(root, file) + # docs = self._load_file(filename) + ke = KnowledgeEmbedding( + file_path=filename, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + ) + client = ke.init_knowledge_embedding() + client.source_embedding() - def query(self, q): - """Query similar doc from Vector""" - vector_store = self.init_vector_store() - docs = vector_store.similarity_search_with_score(q, k=self.top_k) - for doc in docs: - dc, s = doc - yield s, dc if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--vector_name", type=str, default="default") parser.add_argument("--append", type=bool, default=False) - parser.add_argument("--store_type", type=str, default="Chroma") args = parser.parse_args() vector_name = args.vector_name append_mode = args.append @@ -56,5 +52,5 @@ if __name__ == "__main__": vector_store_config = {"vector_store_name": vector_name} print(vector_store_config) kv = LocalKnowledgeInit(vector_store_config=vector_store_config) - vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode) + kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode) print("your knowledge embedding success...") From f2f28fee4232d4aea28c7c06218e998780cfa751 Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Mon, 5 Jun 2023 16:27:52 +0800 Subject: [PATCH 2/4] update:format --- pilot/datasets/mysql/url.md | 1 + pilot/model/adapter.py | 3 +-- pilot/model/guanaco_stream_llm.py | 5 ++--- pilot/scene/chat_knowledge/custom/prompt.py | 4 +++- pilot/scene/chat_knowledge/default/prompt.py | 4 +++- pilot/scene/chat_knowledge/url/prompt.py | 4 +++- pilot/server/chat_adapter.py | 6 +++++- pilot/source_embedding/EncodeTextLoader.py | 8 ++++---- pilot/source_embedding/knowledge_embedding.py | 15 +++++++++++---- pilot/source_embedding/pdf_embedding.py | 1 + tools/knowlege_init.py | 1 - 11 files changed, 34 insertions(+), 18 deletions(-) diff --git a/pilot/datasets/mysql/url.md b/pilot/datasets/mysql/url.md index e69de29bb..20592cb72 100644 --- a/pilot/datasets/mysql/url.md +++ b/pilot/datasets/mysql/url.md @@ -0,0 +1 @@ +LlamaIndex是一个数据框架,旨在帮助您构建LLM应用程序。它包括一个向量存储索引和一个简单的目录阅读器,可以帮助您处理和操作数据。此外,LlamaIndex还提供了一个GPT Index,可以用于数据增强和生成更好的LM模型。 \ No newline at end of file diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 8f32b86fc..05c55fa74 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -82,7 +82,7 @@ class ChatGLMAdapater(BaseLLMAdaper): ) return model, tokenizer - + class GuanacoAdapter(BaseLLMAdaper): """TODO Support guanaco""" @@ -97,7 +97,6 @@ class GuanacoAdapter(BaseLLMAdaper): return model, tokenizer - class GuanacoAdapter(BaseLLMAdaper): """TODO Support guanaco""" diff --git a/pilot/model/guanaco_stream_llm.py b/pilot/model/guanaco_stream_llm.py index 8f72699d1..be70b6a18 100644 --- a/pilot/model/guanaco_stream_llm.py +++ b/pilot/model/guanaco_stream_llm.py @@ -3,7 +3,6 @@ from threading import Thread from transformers import TextIteratorStreamer, StoppingCriteriaList, 
StoppingCriteria - def guanaco_stream_generate_output(model, tokenizer, params, device, context_len=2048): """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py""" tokenizer.bos_token_id = 1 @@ -19,7 +18,7 @@ def guanaco_stream_generate_output(model, tokenizer, params, device, context_len streamer = TextIteratorStreamer( tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True ) - + tokenizer.bos_token_id = 1 stop_token_ids = [0] @@ -52,4 +51,4 @@ def guanaco_stream_generate_output(model, tokenizer, params, device, context_len for new_text in streamer: out += new_text yield new_text - return out \ No newline at end of file + return out diff --git a/pilot/scene/chat_knowledge/custom/prompt.py b/pilot/scene/chat_knowledge/custom/prompt.py index a76bb70ba..4892e28cd 100644 --- a/pilot/scene/chat_knowledge/custom/prompt.py +++ b/pilot/scene/chat_knowledge/custom/prompt.py @@ -28,7 +28,9 @@ _DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users w {question} """ -_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_knowledge/default/prompt.py b/pilot/scene/chat_knowledge/default/prompt.py index 5dc512898..0fd9f9ff3 100644 --- a/pilot/scene/chat_knowledge/default/prompt.py +++ b/pilot/scene/chat_knowledge/default/prompt.py @@ -29,7 +29,9 @@ _DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users w {question} """ -_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/scene/chat_knowledge/url/prompt.py b/pilot/scene/chat_knowledge/url/prompt.py index e887cc53f..3e9659130 100644 --- a/pilot/scene/chat_knowledge/url/prompt.py +++ b/pilot/scene/chat_knowledge/url/prompt.py @@ -28,7 +28,9 @@ _DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users w {question} """ -_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) PROMPT_SEP = SeparatorStyle.SINGLE.value diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 4743c4159..0bd56380b 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -59,6 +59,7 @@ class ChatGLMChatAdapter(BaseChatAdpter): return chatglm_generate_stream + class GuanacoChatAdapter(BaseChatAdpter): """Model chat adapter for Guanaco""" @@ -66,10 +67,13 @@ class GuanacoChatAdapter(BaseChatAdpter): return "guanaco" in model_path def get_generate_stream_func(self): - from pilot.model.llm_out.guanaco_stream_llm import guanaco_stream_generate_output + from pilot.model.llm_out.guanaco_stream_llm import ( + guanaco_stream_generate_output, + ) return guanaco_generate_output + class CodeT5ChatAdapter(BaseChatAdpter): """Model chat adapter for CodeT5""" diff --git a/pilot/source_embedding/EncodeTextLoader.py b/pilot/source_embedding/EncodeTextLoader.py index 7f3ba7e5a..2b7344f18 100644 --- a/pilot/source_embedding/EncodeTextLoader.py +++ b/pilot/source_embedding/EncodeTextLoader.py @@ -15,12 +15,12 @@ class EncodeTextLoader(BaseLoader): def load(self) -> List[Document]: """Load from file path.""" - with 
open(self.file_path, 'rb') as f: + with open(self.file_path, "rb") as f: raw_text = f.read() result = chardet.detect(raw_text) - if result['encoding'] is None: - text = raw_text.decode('utf-8') + if result["encoding"] is None: + text = raw_text.decode("utf-8") else: - text = raw_text.decode(result['encoding']) + text = raw_text.decode(result["encoding"]) metadata = {"source": self.file_path} return [Document(page_content=text, metadata=metadata)] diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 97d9f590b..c81953ffc 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -20,13 +20,14 @@ CFG = Config() KnowledgeEmbeddingType = { ".txt": (MarkdownEmbedding, {}), - ".md": (MarkdownEmbedding,{}), + ".md": (MarkdownEmbedding, {}), ".pdf": (PDFEmbedding, {}), ".doc": (WordEmbedding, {}), ".docx": (WordEmbedding, {}), ".csv": (CSVEmbedding, {}), } + class KnowledgeEmbedding: def __init__( self, @@ -34,7 +35,6 @@ class KnowledgeEmbedding: vector_store_config, file_type: Optional[str] = "default", file_path: Optional[str] = None, - ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path @@ -62,13 +62,20 @@ class KnowledgeEmbedding: extension = "." + self.file_path.rsplit(".", 1)[-1] if extension in KnowledgeEmbeddingType: knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] - embedding = knowledge_class(self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config, **knowledge_args) + embedding = knowledge_class( + self.file_path, + model_name=self.model_name, + vector_store_config=self.vector_store_config, + **knowledge_args, + ) return embedding raise ValueError(f"Unsupported knowledge file type '{extension}'") return embedding def similar_search(self, text, topk): - vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config) + vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) return vector_client.similar_search(text, topk) def vector_exist(self): diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index b3f64b788..55f3783f3 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -20,6 +20,7 @@ class PDFEmbedding(SourceEmbedding): self.model_name = model_name self.vector_store_config = vector_store_config self.encoding = encoding + @register def read(self): """Load from pdf path.""" diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index 018cd3000..e886e4d85 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -40,7 +40,6 @@ class LocalKnowledgeInit: client.source_embedding() - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--vector_name", type=str, default="default") From e29fa37cde7d62bd200bfa82a29e23435c38d64c Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Mon, 5 Jun 2023 18:08:55 +0800 Subject: [PATCH 3/4] update:knowledge env --- .env.template | 8 +++- pilot/configs/config.py | 4 +- pilot/scene/chat_knowledge/custom/chat.py | 3 +- pilot/scene/chat_knowledge/default/chat.py | 3 +- pilot/scene/chat_knowledge/url/chat.py | 3 +- pilot/server/vectordb_qa.py | 6 ++- pilot/server/webserver.py | 1 - pilot/source_embedding/csv_embedding.py | 4 +- pilot/source_embedding/knowledge_embedding.py | 1 - pilot/source_embedding/markdown_embedding.py | 11 +++--- 
pilot/source_embedding/pdf_embedding.py | 11 +++--- pilot/source_embedding/source_embedding.py | 2 - pilot/source_embedding/string_embedding.py | 5 +-- pilot/source_embedding/url_embedding.py | 7 ++-- pilot/source_embedding/word_embedding.py | 11 +++--- pilot/summary/db_summary_client.py | 39 ++----------------- pilot/vector_store/file_loader.py | 2 - tools/knowlege_init.py | 2 - 18 files changed, 43 insertions(+), 80 deletions(-) diff --git a/.env.template b/.env.template index 3e8ae536b..234b12738 100644 --- a/.env.template +++ b/.env.template @@ -28,8 +28,12 @@ MAX_POSITION_EMBEDDINGS=4096 # FAST_LLM_MODEL=chatglm-6b -### EMBEDDINGS -## EMBEDDING_MODEL - Model to use for creating embeddings +#*******************************************************************# +#** EMBEDDING SETTINGS **# +#*******************************************************************# +EMBEDDING_MODEL=text2vec +KNOWLEDGE_CHUNK_SIZE=500 +KNOWLEDGE_SEARCH_TOP_SIZE=5 ## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs ## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs # EMBEDDING_MODEL=all-MiniLM-L6-v2 diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 6f6271477..c4458eaf7 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -148,8 +148,8 @@ class Config(metaclass=Singleton): ### EMBEDDING Configuration self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") - self.KNOWLEDGE_CHUNK_SIZE = os.getenv("KNOWLEDGE_CHUNK_SIZE", 100) - self.KNOWLEDGE_SEARCH_TOP_SIZE = os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10) + self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500)) + self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10)) ### SUMMARY_CONFIG Configuration self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR") diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 8fc0f3d82..a56b2a098 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -14,7 +14,6 @@ from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, LOGDIR, - VECTOR_SEARCH_TOP_K, ) from pilot.scene.chat_knowledge.custom.prompt import prompt @@ -52,7 +51,7 @@ class ChatNewKnowledge(BaseChat): def generate_input_values(self): docs = self.knowledge_embedding_client.similar_search( - self.current_user_input, VECTOR_SEARCH_TOP_K + self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE ) context = [d.page_content for d in docs] context = context[:2000] diff --git a/pilot/scene/chat_knowledge/default/chat.py b/pilot/scene/chat_knowledge/default/chat.py index 1087ee2c0..325b03783 100644 --- a/pilot/scene/chat_knowledge/default/chat.py +++ b/pilot/scene/chat_knowledge/default/chat.py @@ -14,7 +14,6 @@ from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, LOGDIR, - VECTOR_SEARCH_TOP_K, ) from pilot.scene.chat_knowledge.default.prompt import prompt @@ -48,7 +47,7 @@ class ChatDefaultKnowledge(BaseChat): def generate_input_values(self): docs = self.knowledge_embedding_client.similar_search( - self.current_user_input, VECTOR_SEARCH_TOP_K + self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE ) context = [d.page_content for d in docs] context = context[:2000] diff --git a/pilot/scene/chat_knowledge/url/chat.py b/pilot/scene/chat_knowledge/url/chat.py index 2634dc80d..88dc7ad0b 100644 --- a/pilot/scene/chat_knowledge/url/chat.py +++ b/pilot/scene/chat_knowledge/url/chat.py @@ -14,7 +14,6 @@ from 
pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG, LOGDIR, - VECTOR_SEARCH_TOP_K, ) from pilot.scene.chat_knowledge.url.prompt import prompt @@ -56,7 +55,7 @@ class ChatUrlKnowledge(BaseChat): def generate_input_values(self): docs = self.knowledge_embedding_client.similar_search( - self.current_user_input, VECTOR_SEARCH_TOP_K + self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE ) context = [d.page_content for d in docs] context = context[:2000] diff --git a/pilot/server/vectordb_qa.py b/pilot/server/vectordb_qa.py index 9faae5eb8..2a09e6a98 100644 --- a/pilot/server/vectordb_qa.py +++ b/pilot/server/vectordb_qa.py @@ -3,12 +3,14 @@ from langchain.prompts import PromptTemplate -from pilot.configs.model_config import VECTOR_SEARCH_TOP_K +from pilot.configs.config import Config from pilot.conversation import conv_qa_prompt_template, conv_db_summary_templates from pilot.logs import logger from pilot.model.llm_out.vicuna_llm import VicunaLLM from pilot.vector_store.file_loader import KnownLedge2Vector +CFG = Config() + class KnownLedgeBaseQA: def __init__(self) -> None: @@ -22,7 +24,7 @@ class KnownLedgeBaseQA: ) retriever = self.vector_store.as_retriever( - search_kwargs={"k": VECTOR_SEARCH_TOP_K} + search_kwargs={"k": CFG.KNOWLEDGE_SEARCH_TOP_SIZE} ) docs = retriever.get_relevant_documents(query=query) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index f7655fd7d..239fc5d9e 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -634,7 +634,6 @@ def knowledge_embedding_store(vs_id, files): knowledge_embedding_client = KnowledgeEmbedding( file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), model_name=LLM_MODEL_CONFIG["text2vec"], - local_persist=False, vector_store_config={ "vector_store_name": vector_store_name["vs_name"], "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, diff --git a/pilot/source_embedding/csv_embedding.py b/pilot/source_embedding/csv_embedding.py index 8b2e25ff3..0e69574b4 100644 --- a/pilot/source_embedding/csv_embedding.py +++ b/pilot/source_embedding/csv_embedding.py @@ -12,14 +12,12 @@ class CSVEmbedding(SourceEmbedding): def __init__( self, file_path, - model_name, vector_store_config, embedding_args: Optional[Dict] = None, ): """Initialize with csv path.""" - super().__init__(file_path, model_name, vector_store_config) + super().__init__(file_path, vector_store_config) self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config self.embedding_args = embedding_args diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index c81953ffc..1e072c861 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -64,7 +64,6 @@ class KnowledgeEmbedding: knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension] embedding = knowledge_class( self.file_path, - model_name=self.model_name, vector_store_config=self.vector_store_config, **knowledge_args, ) diff --git a/pilot/source_embedding/markdown_embedding.py b/pilot/source_embedding/markdown_embedding.py index da974c366..e2851d122 100644 --- a/pilot/source_embedding/markdown_embedding.py +++ b/pilot/source_embedding/markdown_embedding.py @@ -8,20 +8,21 @@ from bs4 import BeautifulSoup from langchain.document_loaders import TextLoader from langchain.schema import Document -from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.configs.config import Config from 
pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter +CFG = Config() + class MarkdownEmbedding(SourceEmbedding): """markdown embedding for read markdown document.""" - def __init__(self, file_path, model_name, vector_store_config): + def __init__(self, file_path, vector_store_config): """Initialize with markdown path.""" - super().__init__(file_path, model_name, vector_store_config) + super().__init__(file_path, vector_store_config) self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config # self.encoding = encoding @@ -30,7 +31,7 @@ class MarkdownEmbedding(SourceEmbedding): """Load from markdown path.""" loader = EncodeTextLoader(self.file_path) text_splitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE ) return loader.load_and_split(text_splitter) diff --git a/pilot/source_embedding/pdf_embedding.py b/pilot/source_embedding/pdf_embedding.py index 55f3783f3..6eced03f3 100644 --- a/pilot/source_embedding/pdf_embedding.py +++ b/pilot/source_embedding/pdf_embedding.py @@ -5,19 +5,20 @@ from typing import List from langchain.document_loaders import PyPDFLoader from langchain.schema import Document -from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.configs.config import Config from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter +CFG = Config() + class PDFEmbedding(SourceEmbedding): """pdf embedding for read pdf document.""" - def __init__(self, file_path, model_name, vector_store_config, encoding): + def __init__(self, file_path, vector_store_config, encoding): """Initialize with pdf path.""" - super().__init__(file_path, model_name, vector_store_config) + super().__init__(file_path, vector_store_config) self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config self.encoding = encoding @@ -27,7 +28,7 @@ class PDFEmbedding(SourceEmbedding): # loader = UnstructuredPaddlePDFLoader(self.file_path) loader = PyPDFLoader(self.file_path) textsplitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE ) return loader.load_and_split(textsplitter) diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 7db92ea9b..50c7044f9 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -23,13 +23,11 @@ class SourceEmbedding(ABC): def __init__( self, file_path, - model_name, vector_store_config, embedding_args: Optional[Dict] = None, ): """Initialize with Loader url, model_name, vector_store_config""" self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config self.embedding_args = embedding_args self.embeddings = vector_store_config["embeddings"] diff --git a/pilot/source_embedding/string_embedding.py b/pilot/source_embedding/string_embedding.py index b4d7b1228..a1d18ee82 100644 --- a/pilot/source_embedding/string_embedding.py +++ b/pilot/source_embedding/string_embedding.py @@ -8,11 +8,10 @@ from pilot import SourceEmbedding, register class StringEmbedding(SourceEmbedding): """string embedding for read string document.""" - def __init__(self, file_path, 
model_name, vector_store_config): + def __init__(self, file_path, vector_store_config): """Initialize with pdf path.""" - super().__init__(file_path, model_name, vector_store_config) + super().__init__(file_path, vector_store_config) self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config @register diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py index 39224a9f4..a315e6e45 100644 --- a/pilot/source_embedding/url_embedding.py +++ b/pilot/source_embedding/url_embedding.py @@ -16,11 +16,10 @@ CFG = Config() class URLEmbedding(SourceEmbedding): """url embedding for read url document.""" - def __init__(self, file_path, model_name, vector_store_config): + def __init__(self, file_path, vector_store_config): """Initialize with url path.""" - super().__init__(file_path, model_name, vector_store_config) + super().__init__(file_path, vector_store_config) self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config @register @@ -29,7 +28,7 @@ class URLEmbedding(SourceEmbedding): loader = WebBaseLoader(web_path=self.file_path) if CFG.LANGUAGE == "en": text_splitter = CharacterTextSplitter( - chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, + chunk_size=CFG.KNOWLEDGE_CHUNK_SIZE, chunk_overlap=20, length_function=len, ) diff --git a/pilot/source_embedding/word_embedding.py b/pilot/source_embedding/word_embedding.py index 5dd2f0199..1f30f241c 100644 --- a/pilot/source_embedding/word_embedding.py +++ b/pilot/source_embedding/word_embedding.py @@ -5,19 +5,20 @@ from typing import List from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader from langchain.schema import Document -from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE +from pilot.configs.config import Config from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter +CFG = Config() + class WordEmbedding(SourceEmbedding): """word embedding for read word document.""" - def __init__(self, file_path, model_name, vector_store_config): + def __init__(self, file_path, vector_store_config): """Initialize with word path.""" - super().__init__(file_path, model_name, vector_store_config) + super().__init__(file_path, vector_store_config) self.file_path = file_path - self.model_name = model_name self.vector_store_config = vector_store_config @register @@ -25,7 +26,7 @@ class WordEmbedding(SourceEmbedding): """Load from word path.""" loader = UnstructuredWordDocumentLoader(self.file_path) textsplitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE + pdf=True, sentence_size=CFG.KNOWLEDGE_CHUNK_SIZE ) return loader.load_and_split(textsplitter) diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index c5bfcc718..3dfbede72 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -34,24 +34,21 @@ class DBSummaryClient: "embeddings": embeddings, } embedding = StringEmbedding( - db_summary_client.get_summery(), - LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - vector_store_config, + file_path=db_summary_client.get_summery(), + vector_store_config=vector_store_config, ) if not embedding.vector_name_exist(): if CFG.SUMMARY_CONFIG == "FAST": for vector_table_info in db_summary_client.get_summery(): embedding = StringEmbedding( vector_table_info, - LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], vector_store_config, ) embedding.source_embedding() 
else: embedding = StringEmbedding( - db_summary_client.get_summery(), - LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], - vector_store_config, + file_path=db_summary_client.get_summery(), + vector_store_config=vector_store_config, ) embedding.source_embedding() for ( @@ -64,7 +61,6 @@ class DBSummaryClient: } embedding = StringEmbedding( table_summary, - LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], table_vector_store_config, ) embedding.source_embedding() @@ -124,30 +120,3 @@ def _get_llm_response(query, db_input, dbsummary): ) res = chat.nostream_call() return json.loads(res)["table"] - - -# if __name__ == "__main__": -# # summary = DBSummaryClient.get_similar_tables("db_test", "查询在线用户的购物车", 10) -# -# text= """Based on the input "查询在线聊天的用户好友" and the known database information, the tables involved in the user input are "chat_users" and "friends". -# Response: -# -# { -# "table": ["chat_users"] -# }""" -# text = text.rstrip().replace("\n","") -# start = text.find("{") -# end = text.find("}") + 1 -# -# # 从字符串中截取出JSON数据 -# json_str = text[start:end] -# -# # 将JSON数据转换为Python中的字典类型 -# data = json.loads(json_str) -# # pattern = r'{s*"table"s*:s*[[^]]*]s*}' -# # match = re.search(pattern, text) -# # if match: -# # json_string = match.group(0) -# # # 将JSON字符串转换为Python对象 -# # json_obj = json.loads(json_string) -# # print(summary) diff --git a/pilot/vector_store/file_loader.py b/pilot/vector_store/file_loader.py index c42eda7a6..cca027324 100644 --- a/pilot/vector_store/file_loader.py +++ b/pilot/vector_store/file_loader.py @@ -17,7 +17,6 @@ from langchain.vectorstores import Chroma from pilot.configs.model_config import ( DATASETS_DIR, LLM_MODEL_CONFIG, - VECTOR_SEARCH_TOP_K, VECTORE_PATH, ) @@ -41,7 +40,6 @@ class KnownLedge2Vector: embeddings: object = None model_name = LLM_MODEL_CONFIG["sentence-transforms"] - top_k: int = VECTOR_SEARCH_TOP_K def __init__(self, model_name=None) -> None: if not model_name: diff --git a/tools/knowlege_init.py b/tools/knowlege_init.py index e886e4d85..ff13865b4 100644 --- a/tools/knowlege_init.py +++ b/tools/knowlege_init.py @@ -10,7 +10,6 @@ from pilot.configs.config import Config from pilot.configs.model_config import ( DATASETS_DIR, LLM_MODEL_CONFIG, - VECTOR_SEARCH_TOP_K, ) from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding @@ -19,7 +18,6 @@ CFG = Config() class LocalKnowledgeInit: embeddings: object = None - top_k: int = VECTOR_SEARCH_TOP_K def __init__(self, vector_store_config) -> None: self.vector_store_config = vector_store_config From 4b41842277ac2f7a0c596316ec2d845c87de65d2 Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Mon, 5 Jun 2023 21:54:30 +0800 Subject: [PATCH 4/4] feature:db_summary bootstrap load --- pilot/scene/chat_db/auto_execute/chat.py | 3 +- pilot/scene/chat_db/professional_qa/chat.py | 3 +- pilot/server/webserver.py | 15 +++-- pilot/source_embedding/knowledge_embedding.py | 64 ------------------- pilot/summary/db_summary_client.py | 19 ++++-- 5 files changed, 26 insertions(+), 78 deletions(-) diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index 1f4597789..73c732713 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -47,12 +47,13 @@ class ChatWithDbAutoExecute(BaseChat): from pilot.summary.db_summary_client import DBSummaryClient except ImportError: raise ValueError("Could not import DBSummaryClient. 
") + client = DBSummaryClient() input_values = { "input": self.current_user_input, "top_k": str(self.top_k), "dialect": self.database.dialect, "table_info": self.database.table_simple_info(self.db_connect) - # "table_info": DBSummaryClient.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) + # "table_info": client.get_similar_tables(dbname=self.db_name, query=self.current_user_input, topk=self.top_k) } return input_values diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 66b751533..faffcc146 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -45,7 +45,8 @@ class ChatWithDbQA(BaseChat): except ImportError: raise ValueError("Could not import DBSummaryClient. ") if self.db_name: - table_info = DBSummaryClient.get_similar_tables( + client = DBSummaryClient() + table_info = client.get_similar_tables( dbname=self.db_name, query=self.current_user_input, topk=self.top_k ) # table_info = self.database.table_simple_info(self.db_connect) diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index 239fc5d9e..f8626f7b4 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import threading import traceback import argparse import datetime @@ -414,7 +415,7 @@ def build_single_model_ui(): show_label=True, ).style(container=False) - db_selector.change(fn=db_selector_changed, inputs=db_selector) + # db_selector.change(fn=db_selector_changed, inputs=db_selector) sql_mode = gr.Radio( [ @@ -618,10 +619,6 @@ def save_vs_name(vs_name): return vs_name -def db_selector_changed(dbname): - DBSummaryClient.db_summary_embedding(dbname) - - def knowledge_embedding_store(vs_id, files): # vs_path = os.path.join(VS_ROOT_PATH, vs_id) if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id)): @@ -645,6 +642,12 @@ def knowledge_embedding_store(vs_id, files): return vs_id +def async_db_summery(): + client = DBSummaryClient() + thread = threading.Thread(target=client.init_db_summary) + thread.start() + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="0.0.0.0") @@ -661,7 +664,7 @@ if __name__ == "__main__": cfg = Config() dbs = cfg.local_db.get_database_list() - + async_db_summery() cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) # 加载插件可执行命令 diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 1e072c861..27297111b 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -1,14 +1,8 @@ -import os from typing import Optional -import markdown -from bs4 import BeautifulSoup -from langchain.document_loaders import PyPDFLoader, TextLoader from langchain.embeddings import HuggingFaceEmbeddings from pilot.configs.config import Config -from pilot.configs.model_config import DATASETS_DIR, KNOWLEDGE_CHUNK_SPLIT_SIZE -from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter from pilot.source_embedding.csv_embedding import CSVEmbedding from pilot.source_embedding.markdown_embedding import MarkdownEmbedding from pilot.source_embedding.pdf_embedding import PDFEmbedding @@ -82,61 +76,3 @@ class KnowledgeEmbedding: CFG.VECTOR_STORE_TYPE, self.vector_store_config ) return vector_client.vector_name_exists() - - def knowledge_persist_initialization(self, append_mode): - documents = 
self._load_knownlege(self.file_path) - self.vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, self.vector_store_config - ) - self.vector_client.load_document(documents) - return self.vector_client - - def _load_knownlege(self, path): - docments = [] - for root, _, files in os.walk(path, topdown=False): - for file in files: - filename = os.path.join(root, file) - docs = self._load_file(filename) - new_docs = [] - for doc in docs: - doc.metadata = { - "source": doc.metadata["source"].replace(DATASETS_DIR, "") - } - print("doc is embedding...", doc.metadata) - new_docs.append(doc) - docments += new_docs - return docments - - def _load_file(self, filename): - if filename.lower().endswith(".md"): - loader = TextLoader(filename) - text_splitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE - ) - docs = loader.load_and_split(text_splitter) - i = 0 - for d in docs: - content = markdown.markdown(d.page_content) - soup = BeautifulSoup(content, "html.parser") - for tag in soup(["!doctype", "meta", "i.fa"]): - tag.extract() - docs[i].page_content = soup.get_text() - docs[i].page_content = docs[i].page_content.replace("\n", " ") - i += 1 - elif filename.lower().endswith(".pdf"): - loader = PyPDFLoader(filename) - textsplitter = CHNDocumentSplitter( - pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE - ) - docs = loader.load_and_split(textsplitter) - i = 0 - for d in docs: - docs[i].page_content = d.page_content.replace("\n", " ").replace( - "�", "" - ) - i += 1 - else: - loader = TextLoader(filename) - text_splitor = CHNDocumentSplitter(sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE) - docs = loader.load_and_split(text_splitor) - return docs diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 3dfbede72..51f124f62 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -21,8 +21,10 @@ class DBSummaryClient: , get_similar_tables method(get user query related tables info) """ - @staticmethod - def db_summary_embedding(dbname): + def __init__(self): + pass + + def db_summary_embedding(self, dbname): """put db profile and table profile summary into vector store""" if CFG.LOCAL_DB_HOST is not None and CFG.LOCAL_DB_PORT is not None: db_summary_client = MysqlSummary(dbname) @@ -56,7 +58,7 @@ class DBSummaryClient: table_summary, ) in db_summary_client.get_table_summary().items(): table_vector_store_config = { - "vector_store_name": table_name + "_ts", + "vector_store_name": dbname + "_" + table_name + "_ts", "embeddings": embeddings, } embedding = StringEmbedding( @@ -67,8 +69,7 @@ class DBSummaryClient: logger.info("db summary embedding success") - @staticmethod - def get_similar_tables(dbname, query, topk): + def get_similar_tables(self, dbname, query, topk): """get user query related tables info""" vector_store_config = { "vector_store_name": dbname + "_profile", @@ -94,7 +95,7 @@ class DBSummaryClient: related_table_summaries = [] for table in related_tables: vector_store_config = { - "vector_store_name": table + "_ts", + "vector_store_name": dbname + "_" + table + "_ts", } knowledge_embedding_client = KnowledgeEmbedding( file_path="", @@ -105,6 +106,12 @@ class DBSummaryClient: related_table_summaries.append(table_summery[0].page_content) return related_table_summaries + def init_db_summary(self): + db = CFG.local_db + dbs = db.get_database_list() + for dbname in dbs: + self.db_summary_embedding(dbname) + def _get_llm_response(query, db_input, dbsummary): chat_param = {
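
Usage notes on the series follow.

PATCH 1/4 replaces the per-extension if/elif chain in `KnowledgeEmbedding.init_knowledge_embedding` with the `KnowledgeEmbeddingType` registry, so supporting a new format means one dict entry plus a `SourceEmbedding` subclass. A minimal end-to-end sketch, assuming the keyword signature from the PATCH 1/4 hunk; the store name and `intro.md` path are illustrative, not from the patches:

```python
import os

from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

# Any suffix registered in KnowledgeEmbeddingType (.txt/.md/.pdf/.doc/.docx/.csv)
# is routed to the matching embedding class; anything else raises ValueError.
client = KnowledgeEmbedding(
    model_name=LLM_MODEL_CONFIG["text2vec"],
    vector_store_config={
        "vector_store_name": "demo",                      # illustrative store name
        "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
    },
    file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, "demo", "intro.md"),
)
client.knowledge_embedding()                # split, embed, and persist the file
docs = client.similar_search("What is LlamaIndex?", topk=5)
```

One thing worth double-checking before merge: `KnowledgeEmbedding.vector_exist` calls `vector_client.vector_name_exists()`, while `DBSummaryClient` calls `embedding.vector_name_exist()` on a `StringEmbedding`; if those are meant to be the same predicate, the names should be reconciled.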
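With `knowledge_persist_initialization` and `_load_knownlege` deleted in PATCH 4/4, `tools/knowlege_init.py` owns the directory walk itself: each file under `DATASETS_DIR` gets its own `KnowledgeEmbedding`, typed by extension. Roughly equivalent to the following sketch, with config keys as in the patch:

```python
import os

from pilot.configs.model_config import DATASETS_DIR, LLM_MODEL_CONFIG
from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding

# One embedding client per file; the registry picks the loader by suffix.
for root, _, files in os.walk(DATASETS_DIR, topdown=False):
    for name in files:
        ke = KnowledgeEmbedding(
            file_path=os.path.join(root, name),
            model_name=LLM_MODEL_CONFIG["text2vec"],
            vector_store_config={"vector_store_name": "default"},
        )
        ke.init_knowledge_embedding().source_embedding()
```

Two side effects of the rewrite: an unsupported extension now raises `ValueError` and aborts the whole walk (callers may want to catch and skip), and the `--append` flag is still parsed and passed to `knowledge_persist` but no longer appears to be consulted anywhere.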
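`EncodeTextLoader` replaces `TextLoader` for text and markdown files, presumably because `TextLoader` decodes with a single fixed encoding while uploaded knowledge files arrive in a mix of UTF-8 and GBK-family encodings. Its decoding strategy in isolation: detect first, fall back to UTF-8 when `chardet` cannot decide.

```python
import chardet

def read_text_best_effort(path: str) -> str:
    """Decode a file the way EncodeTextLoader.load does."""
    with open(path, "rb") as f:
        raw = f.read()
    # chardet.detect returns e.g. {'encoding': 'GB2312', 'confidence': 0.99, 'language': 'Chinese'}
    result = chardet.detect(raw)
    return raw.decode(result["encoding"] or "utf-8")
```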
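In PATCH 2/4's `chat_adapter.py` hunk, `GuanacoChatAdapter.get_generate_stream_func` imports `guanaco_stream_generate_output` but still returns `guanaco_generate_output`; unless that name is imported elsewhere in the module, this is a `NameError` at call time. The intended body is presumably:

```python
def get_generate_stream_func(self):
    from pilot.model.llm_out.guanaco_stream_llm import (
        guanaco_stream_generate_output,
    )

    return guanaco_stream_generate_output
```

Related: the `adapter.py` hunk in the same patch shows `class GuanacoAdapter(BaseLLMAdaper)` defined twice back to back; the second definition silently shadows the first and one of them should probably be dropped.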
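The `int(...)` wrappers added to `config.py` in PATCH 3/4 are not cosmetic: `os.getenv` passes an unset variable's default through unchanged, but always returns a string when the variable is set, so the PATCH 1/4 version gave `KNOWLEDGE_CHUNK_SIZE` a different type depending on the environment:

```python
import os

os.environ.pop("KNOWLEDGE_CHUNK_SIZE", None)
assert os.getenv("KNOWLEDGE_CHUNK_SIZE", 100) == 100          # unset: int default passes through

os.environ["KNOWLEDGE_CHUNK_SIZE"] = "500"
assert os.getenv("KNOWLEDGE_CHUNK_SIZE", 100) == "500"        # set: always a str
assert int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 500)) == 500     # patched form: always an int
```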
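PATCH 4/4 turns `DBSummaryClient` into an instantiable class and moves the summary embedding off the webserver's startup path via `async_db_summery`, so the UI comes up while per-database profiles are still being built; it also namespaces per-table stores as `<dbname>_<table>_ts`, avoiding collisions when two databases share table names. A sketch of the bootstrap pattern, with a `daemon` flag added here as an assumption (the patch itself starts a plain non-daemon thread):

```python
import threading

from pilot.summary.db_summary_client import DBSummaryClient

def start_db_summary_bootstrap() -> threading.Thread:
    """Embed per-database summaries in the background at server start."""
    client = DBSummaryClient()
    # daemon=True is this sketch's choice so the worker cannot block shutdown;
    # the patch uses threading.Thread(target=client.init_db_summary) directly.
    thread = threading.Thread(target=client.init_db_summary, daemon=True)
    thread.start()
    return thread
```

Because `init_db_summary` iterates every database and table, failures inside the thread never surface in the request path; logging around `self.db_summary_embedding(dbname)` would make bootstrap problems visible.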