From 496696537d2a73d4076ca4560a434a10bca532ce Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 17 Oct 2023 11:52:45 +0800 Subject: [PATCH] fix:vectordb lazy load --- pilot/vector_store/__init__.py | 19 ++++++++----- pilot/vector_store/base.py | 2 +- pilot/vector_store/connector.py | 12 ++++---- pilot/vector_store/milvus_store.py | 15 +++++++--- pilot/vector_store/pgvector_store.py | 28 +++++++++---------- .../unit_tests/vector_store/test_pgvector.py | 7 +++-- 6 files changed, 47 insertions(+), 36 deletions(-) diff --git a/pilot/vector_store/__init__.py b/pilot/vector_store/__init__.py index daca3b81c..ff7e70dbc 100644 --- a/pilot/vector_store/__init__.py +++ b/pilot/vector_store/__init__.py @@ -1,21 +1,30 @@ from typing import Any + def _import_pgvector() -> Any: - from pilot.vector_store.pgvector_store import PGVectorStore + from pilot.vector_store.pgvector_store import PGVectorStore + return PGVectorStore + def _import_milvus() -> Any: from pilot.vector_store.milvus_store import MilvusStore + return MilvusStore + def _import_chroma() -> Any: from pilot.vector_store.chroma_store import ChromaStore + return ChromaStore + def _import_weaviate() -> Any: from pilot.vector_store.weaviate_store import WeaviateStore + return WeaviateStore + def __getattr__(name: str) -> Any: if name == "Chroma": return _import_chroma() @@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any: else: raise AttributeError(f"Could not find: {name}") -__all__ = [ - "Chroma", - "Milvus", - "Weaviate", - "PGVector" -] \ No newline at end of file + +__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"] diff --git a/pilot/vector_store/base.py b/pilot/vector_store/base.py index 7eac8aa25..eb746c7a8 100644 --- a/pilot/vector_store/base.py +++ b/pilot/vector_store/base.py @@ -17,7 +17,7 @@ class VectorStoreBase(ABC): @abstractmethod def vector_name_exists(self) -> bool: """is vector store name exist.""" - return False + return False @abstractmethod def delete_by_ids(self, ids): diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index efc248aba..fd2198c0f 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase connector = {} + class VectorStoreConnector: """VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1. 1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate) @@ -16,16 +17,15 @@ class VectorStoreConnector: """initialize vector store connector.""" self.ctx = ctx self._register() - + if self._match(vector_store_type): self.connector_class = connector.get(vector_store_type) else: raise Exception(f"Vector Type Not support. {0}", vector_store_type) - - print(self.connector_class) + + print(self.connector_class) self.client = self.connector_class(ctx) - def load_document(self, docs): """load document in vector database.""" return self.client.load_document(docs) @@ -51,9 +51,9 @@ class VectorStoreConnector: return True else: return False - + def _register(self): for cls in vector_store.__all__: if issubclass(getattr(vector_store, cls), VectorStoreBase): _k, _v = cls, getattr(vector_store, cls) - connector.update({_k: _v}) \ No newline at end of file + connector.update({_k: _v}) diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 5deca8b47..ee304fe25 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -3,7 +3,6 @@ import logging import os from typing import Any, Iterable, List, Optional, Tuple -from pymilvus import Collection, DataType, connections, utility from pilot.vector_store.base import VectorStoreBase @@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase): """Milvus database""" def __init__(self, ctx: {}) -> None: + from pymilvus import Collection, DataType, connections, utility + """init a milvus storage connection. Args: @@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase): DataType, FieldSchema, connections, + utility, ) from pymilvus.orm.types import infer_dtype_bydata except ImportError: @@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase): return doc_ids def similar_search(self, text, topk) -> None: + from pymilvus import Collection, DataType + """similar_search in vector database.""" self.col = Collection(self.collection_name) schema = self.col.schema @@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase): return data[0], ret def vector_name_exists(self): + from pymilvus import utility + """is vector store name exist.""" return utility.has_collection(self.collection_name) def delete_vector_name(self, vector_name): + from pymilvus import utility + """milvus delete collection name""" logger.info(f"milvus vector_name:{vector_name} begin delete...") utility.drop_collection(vector_name) return True def delete_by_ids(self, ids): + from pymilvus import Collection + self.col = Collection(self.collection_name) """milvus delete vectors by ids""" logger.info(f"begin delete milvus ids...") @@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase): delet_expr = f"{self.primary_field} in {doc_ids}" self.col.delete(delet_expr) return True - - def close(self): - connections.disconnect() diff --git a/pilot/vector_store/pgvector_store.py b/pilot/vector_store/pgvector_store.py index 98ce4a027..5f6661871 100644 --- a/pilot/vector_store/pgvector_store.py +++ b/pilot/vector_store/pgvector_store.py @@ -7,32 +7,32 @@ logger = logging.getLogger(__name__) CFG = Config() + class PGVectorStore(VectorStoreBase): - """`Postgres.PGVector` vector store. - + """`Postgres.PGVector` vector store. + To use this, you should have the ``pgvector`` python package installed. """ def __init__(self, ctx: dict) -> None: """init pgvector storage""" - + from langchain.vectorstores import PGVector - + self.ctx = ctx self.connection_string = ctx.get("connection_string", None) self.embeddings = ctx.get("embeddings", None) self.collection_name = ctx.get("vector_store_name", None) - + self.vector_store_client = PGVector( embedding_function=self.embeddings, collection_name=self.collection_name, - connection_string=self.connection_string + connection_string=self.connection_string, ) - - def similar_search(self, text, topk, **kwargs: Any) -> None: - return self.vector_store_client.similarity_search(text, topk) - + def similar_search(self, text, topk, **kwargs: Any) -> None: + return self.vector_store_client.similarity_search(text, topk) + def vector_name_exists(self): try: self.vector_store_client.create_collection() @@ -40,14 +40,12 @@ class PGVectorStore(VectorStoreBase): except Exception as e: logger.error("vector_name_exists error", e.message) return False - + def load_document(self, documents) -> None: return self.vector_store_client.from_documents(documents) - def delete_vector_name(self, vector_name): - return self.vector_store_client.delete_collection() + return self.vector_store_client.delete_collection() - def delete_by_ids(self, ids): - return self.vector_store_client.delete(ids) \ No newline at end of file + return self.vector_store_client.delete(ids) diff --git a/tests/unit_tests/vector_store/test_pgvector.py b/tests/unit_tests/vector_store/test_pgvector.py index c96643683..59319a124 100644 --- a/tests/unit_tests/vector_store/test_pgvector.py +++ b/tests/unit_tests/vector_store/test_pgvector.py @@ -3,8 +3,9 @@ import pytest from pilot import vector_store from pilot.vector_store.base import VectorStoreBase -def test_vetorestore_imports() -> None: - """ Simple test to make sure all things can be imported.""" - for cls in vector_store.__all__: +def test_vetorestore_imports() -> None: + """Simple test to make sure all things can be imported.""" + + for cls in vector_store.__all__: assert issubclass(getattr(vector_store, cls), VectorStoreBase)