feature:add knowledge embedding

This commit is contained in:
aries-ckt
2023-05-15 22:12:50 +08:00
parent 3c795154b2
commit ce4c3e823d
8 changed files with 88 additions and 38 deletions

View File

@@ -10,6 +10,7 @@ class CSVEmbedding(SourceEmbedding):
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
"""Initialize with csv path."""
super().__init__(file_path, model_name, vector_store_config)
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config

View File

@@ -4,17 +4,31 @@ from pilot.source_embedding.pdf_embedding import PDFEmbedding
class KnowledgeEmbedding:
@staticmethod
def knowledge_embedding(file_path:str, model_name, vector_store_config):
if file_path.endswith(".pdf"):
embedding = PDFEmbedding(file_path=file_path, model_name=model_name,
vector_store_config=vector_store_config)
elif file_path.endswith(".md"):
embedding = MarkdownEmbedding(file_path=file_path, model_name=model_name,
vector_store_config=vector_store_config)
def __init__(self, file_path, model_name, vector_store_config):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
self.vector_store_type = "default"
self.knowledge_embedding_client = self.init_knowledge_embedding()
elif file_path.endswith(".csv"):
embedding = CSVEmbedding(file_path=file_path, model_name=model_name,
vector_store_config=vector_store_config)
def knowledge_embedding(self):
self.knowledge_embedding_client.source_embedding()
return embedding
def init_knowledge_embedding(self):
if self.file_path.endswith(".pdf"):
embedding = PDFEmbedding(file_path=self.file_path, model_name=self.model_name,
vector_store_config=self.vector_store_config)
elif self.file_path.endswith(".md"):
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
elif self.file_path.endswith(".csv"):
embedding = CSVEmbedding(file_path=self.file_path, model_name=self.model_name,
vector_store_config=self.vector_store_config)
elif self.vector_store_type == "default":
embedding = MarkdownEmbedding(file_path=self.file_path, model_name=self.model_name, vector_store_config=self.vector_store_config)
return embedding
def similar_search(self, text, topk):
return self.knowledge_embedding_client.similar_search(text, topk)

View File

@@ -15,6 +15,7 @@ class MarkdownEmbedding(SourceEmbedding):
def __init__(self, file_path, model_name, vector_store_config):
"""Initialize with markdown path."""
super().__init__(file_path, model_name, vector_store_config)
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config

View File

@@ -13,9 +13,12 @@ class PDFEmbedding(SourceEmbedding):
def __init__(self, file_path, model_name, vector_store_config):
"""Initialize with pdf path."""
super().__init__(file_path, model_name, vector_store_config)
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
# SourceEmbedding(file_path =file_path, );
SourceEmbedding(file_path, model_name, vector_store_config)
@register
def read(self):

View File

@@ -50,7 +50,7 @@
#
# # text_embeddings = Text2Vectors()
# mivuls = MilvusStore(cfg={"url": "127.0.0.1", "port": "19530", "alias": "default", "table_name": "test_k"})
#
#
# mivuls.insert(["textc","tezt2"])
# print("success")
# ct

View File

@@ -22,12 +22,16 @@ class SourceEmbedding(ABC):
Implementations should implement the method
"""
def __init__(self, yuque_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
"""Initialize with YuqueLoader url, model_name, vector_store_config"""
self.yuque_path = yuque_path
def __init__(self, file_path, model_name, vector_store_config, embedding_args: Optional[Dict] = None):
"""Initialize with Loader url, model_name, vector_store_config"""
self.file_path = file_path
self.model_name = model_name
self.vector_store_config = vector_store_config
self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store_client = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
@abstractmethod
@register
@@ -50,18 +54,16 @@ class SourceEmbedding(ABC):
@register
def index_to_store(self, docs):
"""index to vector store"""
embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
persist_dir = os.path.join(self.vector_store_config["vector_store_path"],
self.vector_store_config["vector_store_name"] + ".vectordb")
self.vector_store = Chroma.from_documents(docs, embeddings, persist_directory=persist_dir)
self.vector_store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
self.vector_store.persist()
@register
def similar_search(self, doc, topk):
"""vector store similarity_search"""
return self.vector_store.similarity_search(doc, topk)
return self.vector_store_client.similarity_search(doc, topk)
def source_embedding(self):
if 'read' in registered_methods: