diff --git a/pilot/source_embedding/knowledge_embedding.py b/pilot/source_embedding/knowledge_embedding.py index 97b515897..d8ec85ae3 100644 --- a/pilot/source_embedding/knowledge_embedding.py +++ b/pilot/source_embedding/knowledge_embedding.py @@ -51,6 +51,7 @@ class KnowledgeEmbedding: self.knowledge_embedding_client.index_to_store(docs) def read(self): + self.knowledge_embedding_client = self.init_knowledge_embedding() return self.knowledge_embedding_client.read_batch() def init_knowledge_embedding(self): diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 3d881fcdf..372e35c22 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -33,9 +33,6 @@ class SourceEmbedding(ABC): self.vector_store_config = vector_store_config self.embedding_args = embedding_args self.embeddings = vector_store_config["embeddings"] - self.vector_client = VectorStoreConnector( - CFG.VECTOR_STORE_TYPE, vector_store_config - ) @abstractmethod @register @@ -59,11 +56,17 @@ class SourceEmbedding(ABC): @register def index_to_store(self, docs): """index to vector store""" + self.vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) self.vector_client.load_document(docs) @register def similar_search(self, doc, topk): """vector store similarity_search""" + self.vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) try: ans = self.vector_client.similar_search(doc, topk) except NotEnoughElementsException: @@ -71,6 +74,9 @@ class SourceEmbedding(ABC): return ans def vector_name_exist(self): + self.vector_client = VectorStoreConnector( + CFG.VECTOR_STORE_TYPE, self.vector_store_config + ) return self.vector_client.vector_name_exists() def source_embedding(self): diff --git a/tools/knowledge_init.py b/tools/knowledge_init.py index c9a0c5457..74c62b90a 100644 --- a/tools/knowledge_init.py +++ b/tools/knowledge_init.py @@ -25,17 +25,20 @@ class LocalKnowledgeInit: def knowledge_persist(self, file_path): """knowledge persist""" + docs = [] + embedding_engine = None for root, _, files in os.walk(file_path, topdown=False): for file in files: filename = os.path.join(root, file) - # docs = self._load_file(filename) ke = KnowledgeEmbedding( file_path=filename, model_name=self.model_name, vector_store_config=self.vector_store_config, ) - client = ke.init_knowledge_embedding() - client.source_embedding() + embedding_engine = ke.init_knowledge_embedding() + doc = ke.read() + docs.extend(doc) + embedding_engine.index_to_store(docs) if __name__ == "__main__":