mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-19 00:37:34 +00:00

feature: knowledge embedding supports automatic file path adaptation

This commit is contained in:
parent 31d457cfd5
commit be1a792d3c
@@ -148,6 +148,8 @@ class Config(metaclass=Singleton):

         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        self.KNOWLEDGE_CHUNK_SIZE = os.getenv("KNOWLEDGE_CHUNK_SIZE", 100)
+        self.KNOWLEDGE_SEARCH_TOP_SIZE = os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10)

         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "VECTOR")
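Note: os.getenv returns a str whenever the variable is set in the environment, so the int defaults above (100, 10) only apply when the variable is absent. A cast keeps the type stable in both cases; this is a hedged suggestion, not part of the diff:

    # ensure int regardless of whether the env var is set (suggestion only)
    self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
    self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 10))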
@@ -34,7 +34,6 @@ LLM_MODEL_CONFIG = {
     "chatglm-6b-int4": os.path.join(MODEL_PATH, "chatglm-6b-int4"),
     "chatglm-6b": os.path.join(MODEL_PATH, "chatglm-6b"),
     "text2vec-base": os.path.join(MODEL_PATH, "text2vec-base-chinese"),
-    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "proxyllm": "proxyllm",
@@ -46,9 +46,7 @@ class ChatNewKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )
@@ -14,13 +14,21 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""


-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
 已知内容:
 {context}
 问题:
 {question}
 """
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH


 PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -42,9 +42,7 @@ class ChatDefaultKnowledge(BaseChat):
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )
@@ -15,13 +15,21 @@ PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelli
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """


-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
 已知内容:
 {context}
 问题:
 {question}
 """
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH


 PROMPT_SEP = SeparatorStyle.SINGLE.value
@@ -40,15 +40,13 @@ class ChatUrlKnowledge(BaseChat):
         self.url = url
         vector_store_config = {
             "vector_store_name": url,
             "text_field": "content",
             "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
         }
         self.knowledge_embedding_client = KnowledgeEmbedding(
-            file_path=url,
-            file_type="url",
             model_name=LLM_MODEL_CONFIG["text2vec"],
-            local_persist=False,
             vector_store_config=vector_store_config,
+            file_type="url",
+            file_path=url,
         )

         # url source in vector
@@ -14,20 +14,21 @@ CFG = Config()
 PROMPT_SCENE_DEFINE = """A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge.
 The assistant gives helpful, detailed, professional and polite answers to the user's questions. """

-# _DEFAULT_TEMPLATE = """ Based on the known information, provide professional and concise answers to the user's questions. If the answer cannot be obtained from the provided content, please say: 'The information provided in the knowledge base is not sufficient to answer this question.' Fabrication is prohibited.。
-# known information:
-# {context}
-# question:
-# {question}
-# """
-_DEFAULT_TEMPLATE = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
+_DEFAULT_TEMPLATE_ZH = """ 基于以下已知的信息, 专业、简要的回答用户的问题,
 如果无法从提供的内容中获取答案, 请说: "知识库中提供的内容不足以回答此问题" 禁止胡乱编造。
 已知内容:
 {context}
 问题:
 {question}
 """
+_DEFAULT_TEMPLATE_EN = """ Based on the known information below, provide users with professional and concise answers to their questions. If the answer cannot be obtained from the provided content, please say: "The information provided in the knowledge base is not sufficient to answer this question." It is forbidden to make up information randomly.
+known information:
+{context}
+question:
+{question}
+"""
+
+_DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == 'en' else _DEFAULT_TEMPLATE_ZH


 PROMPT_SEP = SeparatorStyle.SINGLE.value
pilot/source_embedding/EncodeTextLoader.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+from typing import List, Optional
+import chardet
+
+from langchain.docstore.document import Document
+from langchain.document_loaders.base import BaseLoader
+
+
+class EncodeTextLoader(BaseLoader):
+    """Load text files."""
+
+    def __init__(self, file_path: str, encoding: Optional[str] = None):
+        """Initialize with file path."""
+        self.file_path = file_path
+        self.encoding = encoding
+
+    def load(self) -> List[Document]:
+        """Load from file path."""
+        with open(self.file_path, 'rb') as f:
+            raw_text = f.read()
+            result = chardet.detect(raw_text)
+            if result['encoding'] is None:
+                text = raw_text.decode('utf-8')
+            else:
+                text = raw_text.decode(result['encoding'])
+            metadata = {"source": self.file_path}
+            return [Document(page_content=text, metadata=metadata)]
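Note: chardet.detect sniffs the raw bytes, so GBK/GB18030 and UTF-8 documents load without an explicit encoding; the stored encoding argument is currently unused, and a wrong low-confidence guess would still surface as a UnicodeDecodeError at decode time. A minimal usage sketch, assuming the module path above; the file name is illustrative:

    from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader

    docs = EncodeTextLoader("./datasets/sample.md").load()  # encoding detected from bytes
    print(docs[0].metadata["source"], len(docs[0].page_content))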
@@ -1,4 +1,5 @@
 import os
+from typing import Optional

 import markdown
 from bs4 import BeautifulSoup
@@ -12,19 +13,28 @@ from pilot.source_embedding.csv_embedding import CSVEmbedding
 from pilot.source_embedding.markdown_embedding import MarkdownEmbedding
 from pilot.source_embedding.pdf_embedding import PDFEmbedding
 from pilot.source_embedding.url_embedding import URLEmbedding
+from pilot.source_embedding.word_embedding import WordEmbedding
+from pilot.vector_store.connector import VectorStoreConnector

 CFG = Config()

+KnowledgeEmbeddingType = {
+    ".txt": (MarkdownEmbedding, {}),
+    ".md": (MarkdownEmbedding, {}),
+    ".pdf": (PDFEmbedding, {}),
+    ".doc": (WordEmbedding, {}),
+    ".docx": (WordEmbedding, {}),
+    ".csv": (CSVEmbedding, {}),
+}
+

 class KnowledgeEmbedding:
     def __init__(
         self,
-        file_path,
         model_name,
         vector_store_config,
-        local_persist=True,
-        file_type="default",
+        file_type: Optional[str] = "default",
+        file_path: Optional[str] = None,
     ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
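Note: file_path is now an optional keyword argument, so embedding callers and search-only callers share one constructor. A hedged sketch of the embedding flow; the store name and file path are illustrative, and the import of KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config is an assumption:

    from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH, LLM_MODEL_CONFIG

    embedder = KnowledgeEmbedding(
        model_name=LLM_MODEL_CONFIG["text2vec"],
        vector_store_config={
            "vector_store_name": "intro_docs",
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
        },
        file_path="./datasets/intro.md",
    )
    embedder.knowledge_embedding()  # picks the loader via KnowledgeEmbeddingType, then embeds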
@@ -33,11 +43,9 @@ class KnowledgeEmbedding:
         self.file_type = file_type
         self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
         self.vector_store_config["embeddings"] = self.embeddings
-        self.local_persist = local_persist
-        if not self.local_persist:
-            self.knowledge_embedding_client = self.init_knowledge_embedding()

     def knowledge_embedding(self):
+        self.knowledge_embedding_client = self.init_knowledge_embedding()
         self.knowledge_embedding_client.source_embedding()

     def knowledge_embedding_batch(self):
@@ -50,40 +58,24 @@ class KnowledgeEmbedding:
                 model_name=self.model_name,
                 vector_store_config=self.vector_store_config,
             )
+            return embedding
-        elif self.file_path.endswith(".pdf"):
-            embedding = PDFEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-        elif self.file_path.endswith(".md"):
-            embedding = MarkdownEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-        elif self.file_path.endswith(".csv"):
-            embedding = CSVEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-        elif self.file_type == "default":
-            embedding = MarkdownEmbedding(
-                file_path=self.file_path,
-                model_name=self.model_name,
-                vector_store_config=self.vector_store_config,
-            )
-
-        return embedding
+        extension = "." + self.file_path.rsplit(".", 1)[-1]
+        if extension in KnowledgeEmbeddingType:
+            knowledge_class, knowledge_args = KnowledgeEmbeddingType[extension]
+            embedding = knowledge_class(self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config, **knowledge_args)
+            return embedding
+        raise ValueError(f"Unsupported knowledge file type '{extension}'")

     def similar_search(self, text, topk):
-        return self.knowledge_embedding_client.similar_search(text, topk)
+        vector_client = VectorStoreConnector(CFG.VECTOR_STORE_TYPE, self.vector_store_config)
+        return vector_client.similar_search(text, topk)

     def vector_exist(self):
-        return self.knowledge_embedding_client.vector_name_exist()
+        vector_client = VectorStoreConnector(
+            CFG.VECTOR_STORE_TYPE, self.vector_store_config
+        )
+        return vector_client.vector_name_exists()

     def knowledge_persist_initialization(self, append_mode):
         documents = self._load_knownlege(self.file_path)
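Note: the endswith if/elif chain becomes a table-driven lookup, so supporting a new format is one KnowledgeEmbeddingType entry instead of a new branch, and unknown extensions now fail loudly instead of silently falling through to MarkdownEmbedding. One caveat: "." + path.rsplit(".", 1)[-1] yields the whole name for extension-less files (os.path.splitext would yield ""), so such files hit the ValueError. A hedged registration sketch; HTMLEmbedding is hypothetical and not part of this commit:

    # one-line registration of a future loader; the per-type kwargs dict can
    # carry loader-specific options such as an encoding
    KnowledgeEmbeddingType[".html"] = (HTMLEmbedding, {})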
@@ -10,6 +10,7 @@ from langchain.schema import Document

 from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
 from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.EncodeTextLoader import EncodeTextLoader
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter


@@ -22,11 +23,12 @@ class MarkdownEmbedding(SourceEmbedding):
         self.file_path = file_path
         self.model_name = model_name
         self.vector_store_config = vector_store_config
+        # self.encoding = encoding

     @register
     def read(self):
         """Load from markdown path."""
-        loader = TextLoader(self.file_path)
+        loader = EncodeTextLoader(self.file_path)
         text_splitter = CHNDocumentSplitter(
             pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
         )
@@ -13,13 +13,13 @@ from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 class PDFEmbedding(SourceEmbedding):
     """pdf embedding for read pdf document."""

-    def __init__(self, file_path, model_name, vector_store_config):
+    def __init__(self, file_path, model_name, vector_store_config, encoding):
         """Initialize with pdf path."""
         super().__init__(file_path, model_name, vector_store_config)
         self.file_path = file_path
         self.model_name = model_name
         self.vector_store_config = vector_store_config
+        self.encoding = encoding

     @register
     def read(self):
         """Load from pdf path."""
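Note: encoding is added as a required positional parameter, but the registry entry ".pdf": (PDFEmbedding, {}) passes no extra kwargs, so instantiating a PDF through KnowledgeEmbedding would raise TypeError as the diff stands. A defaulted parameter would keep both call sites working; this is a hedged suggestion, not part of the commit:

    def __init__(self, file_path, model_name, vector_store_config, encoding=None):
        """Initialize with pdf path; a default keeps registry callers working."""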
pilot/source_embedding/word_embedding.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+from typing import List
+
+from langchain.document_loaders import PyPDFLoader, UnstructuredWordDocumentLoader
+from langchain.schema import Document
+
+from pilot.configs.model_config import KNOWLEDGE_CHUNK_SPLIT_SIZE
+from pilot.source_embedding import SourceEmbedding, register
+from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
+
+
+class WordEmbedding(SourceEmbedding):
+    """word embedding for read word document."""
+
+    def __init__(self, file_path, model_name, vector_store_config):
+        """Initialize with word path."""
+        super().__init__(file_path, model_name, vector_store_config)
+        self.file_path = file_path
+        self.model_name = model_name
+        self.vector_store_config = vector_store_config
+
+    @register
+    def read(self):
+        """Load from word path."""
+        loader = UnstructuredWordDocumentLoader(self.file_path)
+        textsplitter = CHNDocumentSplitter(
+            pdf=True, sentence_size=KNOWLEDGE_CHUNK_SPLIT_SIZE
+        )
+        return loader.load_and_split(textsplitter)
+
+    @register
+    def data_process(self, documents: List[Document]):
+        i = 0
+        for d in documents:
+            documents[i].page_content = d.page_content.replace("\n", "")
+            i += 1
+        return documents
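Note: UnstructuredWordDocumentLoader needs the `unstructured` package at runtime, and the PyPDFLoader import is unused here; the manual index counter in data_process could also be written with enumerate. A minimal usage sketch with an illustrative path, assuming LLM_MODEL_CONFIG from pilot.configs.model_config (the registry normally constructs this class for .doc/.docx files):

    we = WordEmbedding(
        file_path="./datasets/handbook.docx",
        model_name=LLM_MODEL_CONFIG["text2vec"],
        vector_store_config={"vector_store_name": "handbook"},
    )
    chunks = we.data_process(we.read())  # split, then strip newlines per chunk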
@@ -74,17 +74,11 @@ class DBSummaryClient:
     @staticmethod
     def get_similar_tables(dbname, query, topk):
         """get user query related tables info"""
-        embeddings = HuggingFaceEmbeddings(
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
-        )
         vector_store_config = {
             "vector_store_name": dbname + "_profile",
-            "embeddings": embeddings,
         }
         knowledge_embedding_client = KnowledgeEmbedding(
-            file_path="",
             model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-            local_persist=False,
             vector_store_config=vector_store_config,
         )
         if CFG.SUMMARY_CONFIG == "FAST":

@@ -105,12 +99,10 @@ class DBSummaryClient:
         for table in related_tables:
             vector_store_config = {
                 "vector_store_name": table + "_ts",
                 "embeddings": embeddings,
             }
             knowledge_embedding_client = KnowledgeEmbedding(
-                file_path="",
                 model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
-                local_persist=False,
                 vector_store_config=vector_store_config,
             )
             table_summery = knowledge_embedding_client.similar_search(query, 1)
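Note: with file_path optional and KnowledgeEmbedding building its own HuggingFaceEmbeddings from model_name, these search-only call sites no longer need placeholder arguments or a separately constructed embeddings object; similar_search now goes straight to the vector store through VectorStoreConnector, so no loader or file is touched. A hedged sketch of the pattern (names are illustrative):

    vector_store_config = {"vector_store_name": dbname + "_profile"}
    client = KnowledgeEmbedding(
        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
        vector_store_config=vector_store_config,
    )
    related_tables = client.similar_search(query, topk)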
@@ -19,36 +19,32 @@ CFG = Config()

 class LocalKnowledgeInit:
     embeddings: object = None
-    model_name = LLM_MODEL_CONFIG["text2vec"]
     top_k: int = VECTOR_SEARCH_TOP_K

     def __init__(self, vector_store_config) -> None:
         self.vector_store_config = vector_store_config
+        self.model_name = LLM_MODEL_CONFIG["text2vec"]

     def knowledge_persist(self, file_path, append_mode):
         """knowledge persist"""
-        kv = KnowledgeEmbedding(
-            file_path=file_path,
-            model_name=LLM_MODEL_CONFIG["text2vec"],
-            vector_store_config=self.vector_store_config,
-        )
-        vector_store = kv.knowledge_persist_initialization(append_mode)
-        return vector_store
+        for root, _, files in os.walk(file_path, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                # docs = self._load_file(filename)
+                ke = KnowledgeEmbedding(
+                    file_path=filename,
+                    model_name=self.model_name,
+                    vector_store_config=self.vector_store_config,
+                )
+                client = ke.init_knowledge_embedding()
+                client.source_embedding()

     def query(self, q):
         """Query similar doc from Vector"""
         vector_store = self.init_vector_store()
         docs = vector_store.similarity_search_with_score(q, k=self.top_k)
         for doc in docs:
             dc, s = doc
             yield s, dc


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--vector_name", type=str, default="default")
     parser.add_argument("--append", type=bool, default=False)
     parser.add_argument("--store_type", type=str, default="Chroma")
     args = parser.parse_args()
     vector_name = args.vector_name
     append_mode = args.append

@@ -56,5 +52,5 @@ if __name__ == "__main__":
     vector_store_config = {"vector_store_name": vector_name}
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    vector_store = kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
+    kv.knowledge_persist(file_path=DATASETS_DIR, append_mode=append_mode)
     print("your knowledge embedding success...")
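Note: knowledge_persist now walks the dataset directory and embeds file by file instead of returning a vector store (append_mode is still accepted but no longer used). Because init_knowledge_embedding() raises ValueError on unknown extensions, one stray file under DATASETS_DIR aborts the whole walk. A hedged sketch of a more forgiving loop; the skip-filter is a suggestion, and the module path is assumed:

    import os
    from pilot.source_embedding.knowledge_embedding import KnowledgeEmbedding, KnowledgeEmbeddingType

    def embed_tree(root_dir, model_name, vector_store_config):
        """Embed every supported file under root_dir, skipping unknown formats."""
        for root, _, files in os.walk(root_dir, topdown=False):
            for name in files:
                if "." + name.rsplit(".", 1)[-1] not in KnowledgeEmbeddingType:
                    continue  # skip unsupported formats instead of raising
                KnowledgeEmbedding(
                    file_path=os.path.join(root, name),
                    model_name=model_name,
                    vector_store_config=vector_store_config,
                ).knowledge_embedding()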