mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-10 05:19:44 +00:00
feature:add knowledge embedding
This commit is contained in:
@@ -10,6 +10,7 @@ class CSVEmbedding(SourceEmbedding):
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
    """Initialize with csv path.

    Args:
        file_path: path of the csv file to embed.
        model_name: embedding model name, forwarded to the base class.
        vector_store_config: vector store settings, forwarded to the base class.
        embedding_args: optional extra embedding arguments.
    """
    # Forward embedding_args too: the base initializer accepts it, and
    # leaving it out means the base class falls back to its default of None.
    super().__init__(file_path, model_name, vector_store_config, embedding_args)
    self.file_path = file_path
    self.model_name = model_name
    self.vector_store_config = vector_store_config
|
||||
|
@@ -4,17 +4,31 @@ from pilot.source_embedding.pdf_embedding import PDFEmbedding
|
||||
|
||||
|
||||
class KnowledgeEmbedding:
    """Choose a source-embedding implementation from the file extension and delegate to it."""

    def __init__(self, file_path, model_name, vector_store_config):
        """Initialize with Loader url, model_name, vector_store_config"""
        self.file_path = file_path
        self.model_name = model_name
        self.vector_store_config = vector_store_config
        self.vector_store_type = "default"
        # Built eagerly so both knowledge_embedding() and similar_search()
        # share the same client.
        self.knowledge_embedding_client = self.init_knowledge_embedding()

    def knowledge_embedding(self):
        """Run the selected client's source_embedding()."""
        self.knowledge_embedding_client.source_embedding()

    def init_knowledge_embedding(self):
        """Return the embedding client matching ``self.file_path``.

        ``.pdf`` -> PDFEmbedding, ``.csv`` -> CSVEmbedding, ``.md`` ->
        MarkdownEmbedding. Any other extension falls back to
        MarkdownEmbedding while ``vector_store_type`` is ``"default"``.

        Raises:
            ValueError: if no branch applies (previously this surfaced as an
                UnboundLocalError on the ``return``).
        """
        client_kwargs = {
            "file_path": self.file_path,
            "model_name": self.model_name,
            "vector_store_config": self.vector_store_config,
        }
        if self.file_path.endswith(".pdf"):
            return PDFEmbedding(**client_kwargs)
        if self.file_path.endswith(".csv"):
            return CSVEmbedding(**client_kwargs)
        if self.file_path.endswith(".md") or self.vector_store_type == "default":
            return MarkdownEmbedding(**client_kwargs)
        raise ValueError(f"unsupported knowledge file type: {self.file_path!r}")

    def similar_search(self, text, topk):
        """Delegate a similarity query to the selected client."""
        return self.knowledge_embedding_client.similar_search(text, topk)
|
@@ -15,6 +15,7 @@ class MarkdownEmbedding(SourceEmbedding):
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config):
    """Set up a markdown embedding for the document at ``file_path``."""
    super().__init__(file_path, model_name, vector_store_config)
    # Keep the constructor arguments on the instance as well.
    self.file_path, self.model_name, self.vector_store_config = (
        file_path,
        model_name,
        vector_store_config,
    )
|
||||
|
@@ -13,9 +13,12 @@ class PDFEmbedding(SourceEmbedding):
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config):
    """Initialize with pdf path.

    Args:
        file_path: path of the pdf file to embed.
        model_name: embedding model name, forwarded to the base class.
        vector_store_config: vector store settings, forwarded to the base class.
    """
    super().__init__(file_path, model_name, vector_store_config)
    self.file_path = file_path
    self.model_name = model_name
    self.vector_store_config = vector_store_config
    # Do not construct a separate SourceEmbedding here: super().__init__
    # above already performs the base set-up, and a second instance would
    # only be discarded.
|
||||
|
||||
@register
|
||||
def read(self):
|
||||
|
@@ -50,7 +50,7 @@
|
||||
#
|
||||
# # text_embeddings = Text2Vectors()
|
||||
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
|
||||
#
|
||||
#
|
||||
# mivuls.insert(["textc","tezt2"])
|
||||
# print("success")
|
||||
# ct
|
||||
|
@@ -22,12 +22,16 @@ class SourceEmbedding(ABC):
|
||||
Implementations should implement the method
|
||||
"""
|
||||
|
||||
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
    """Initialize with Loader url, model_name, vector_store_config

    Args:
        file_path: location of the source document.
        model_name: name passed to HuggingFaceEmbeddings.
        vector_store_config: mapping that must provide
            "vector_store_path" and "vector_store_name".
        embedding_args: optional extra arguments, stored as-is.
    """
    self.file_path = file_path
    self.model_name = model_name
    self.vector_store_config = vector_store_config
    self.embedding_args = embedding_args
    # Load the embedding model once; it is reused by the store client below.
    self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
    # The store lives at <vector_store_path>/<vector_store_name>.vectordb.
    persist_dir = os.path.join(
        self.vector_store_config["vector_store_path"],
        self.vector_store_config["vector_store_name"] + ".vectordb",
    )
    self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
|
||||
|
||||
@abstractmethod
|
||||
@register
|
||||
@@ -50,18 +54,16 @@ class SourceEmbedding(ABC):
|
||||
@register
def index_to_store(self, docs):
    """index to vector store

    Builds a Chroma store from ``docs`` at
    <vector_store_path>/<vector_store_name>.vectordb and persists it.
    """
    persist_dir = os.path.join(
        self.vector_store_config["vector_store_path"],
        self.vector_store_config["vector_store_name"] + ".vectordb",
    )
    # Reuse the embedding model created in __init__ instead of loading a
    # fresh HuggingFaceEmbeddings on every call, and build the store once.
    self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
    self.vector_store.persist()
|
||||
|
||||
@register
def similar_search(self, doc, topk):
    """vector store similarity_search

    Queries the client created in __init__, so a search does not depend on
    index_to_store() having been called first on this instance.
    """
    return self.vector_store_client.similarity_search(doc, topk)
|
||||
|
||||
def source_embedding(self):
|
||||
if 'read' in registered_methods:
|
||||
|
Reference in New Issue
Block a user