diff --git a/pilot/vector_store/__init__.py b/pilot/vector_store/__init__.py index e69de29bb..daca3b81c 100644 --- a/pilot/vector_store/__init__.py +++ b/pilot/vector_store/__init__.py @@ -0,0 +1,36 @@ +from typing import Any + +def _import_pgvector() -> Any: + from pilot.vector_store.pgvector_store import PGVectorStore + return PGVectorStore + +def _import_milvus() -> Any: + from pilot.vector_store.milvus_store import MilvusStore + return MilvusStore + +def _import_chroma() -> Any: + from pilot.vector_store.chroma_store import ChromaStore + return ChromaStore + +def _import_weaviate() -> Any: + from pilot.vector_store.weaviate_store import WeaviateStore + return WeaviateStore + +def __getattr__(name: str) -> Any: + if name == "Chroma": + return _import_chroma() + elif name == "Milvus": + return _import_milvus() + elif name == "Weaviate": + return _import_weaviate() + elif name == "PGVector": + return _import_pgvector() + else: + raise AttributeError(f"Could not find: {name}") + +__all__ = [ + "Chroma", + "Milvus", + "Weaviate", + "PGVector" +] \ No newline at end of file diff --git a/pilot/vector_store/base.py b/pilot/vector_store/base.py index 74cd2f98c..7eac8aa25 100644 --- a/pilot/vector_store/base.py +++ b/pilot/vector_store/base.py @@ -15,9 +15,9 @@ class VectorStoreBase(ABC): pass @abstractmethod - def vector_name_exists(self, text, topk) -> None: + def vector_name_exists(self) -> bool: """is vector store name exist.""" - pass + return False @abstractmethod def delete_by_ids(self, ids): diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index d95fc4821..7ff563415 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,17 +1,9 @@ from pilot.vector_store.chroma_store import ChromaStore - -# from pilot.vector_store.weaviate_store import WeaviateStore +from pilot import vector_store +from pilot.vector_store.base import VectorStoreBase connector = {"Chroma": ChromaStore} -try: - from pilot.vector_store.milvus_store import MilvusStore - - connector["Milvus"] = MilvusStore -except: - pass - - class VectorStoreConnector: """VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1. 1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate) @@ -24,9 +16,16 @@ class VectorStoreConnector: def __init__(self, vector_store_type, ctx: {}) -> None: """initialize vector store connector.""" self.ctx = ctx - self.connector_class = connector[vector_store_type] + self._register() + + if self._match(vector_store_type): + self.connector_class = connector.get(vector_store_type) + else: + raise Exception(f"Vector Type Not support. {0}", vector_store_type) + self.client = self.connector_class(ctx) + def load_document(self, docs): """load document in vector database.""" return self.client.load_document(docs) @@ -46,3 +45,14 @@ class VectorStoreConnector: def delete_by_ids(self, ids): """vector store delete by ids.""" return self.client.delete_by_ids(ids=ids) + + def _match(self, vector_store_type) -> bool: + if connector.get(vector_store_type): + return True + else: + return False + + def _register(self): + for cls in vector_store.__all__: + if issubclass(getattr(vector_store, cls), VectorStoreBase): + connector.update({cls, getattr(vector_store, cls)}) \ No newline at end of file diff --git a/pilot/vector_store/pgvector_store.py b/pilot/vector_store/pgvector_store.py new file mode 100644 index 000000000..8155dd61c --- /dev/null +++ b/pilot/vector_store/pgvector_store.py @@ -0,0 +1,50 @@ +from typing import Any +import logging +from pilot.vector_store.base import VectorStoreBase + +logger = logging.getLogger(__name__) + +class PGVectorStore(VectorStoreBase): + """`Postgres.PGVector` vector store. + + To use this, you should have the ``pgvector`` python package installed. + """ + + def __init__(self, ctx: dict) -> None: + """init pgvector storage""" + + from langchain.vectorstores import PGVector + + self.ctx = ctx + self.connection_string = ctx.get("connection_string", None) + self.embeddings = ctx.get("embeddings", None) + self.collection_name = ctx.get("vector_store_name", None) + + self.vector_store_client = PGVector( + embedding_function=self.embeddings, + collection_name=self.collection_name, + connection_string=self.connection_string + ) + + def similar_search(self, text, topk, **kwargs: Any) -> None: + return self.vector_store_client.similarity_search(text, topk) + + + def vector_name_exists(self): + try: + self.vector_store_client.create_collection() + return True + except Exception as e: + logger.error("vector_name_exists error", e.message) + return False + + def load_document(self, documents) -> None: + return self.vector_store_client.from_documents(documents) + + + def delete_vector_name(self, vector_name): + return self.vector_store_client.delete_collection() + + + def delete_by_ids(self, ids): + return self.vector_store_client.delete(ids) \ No newline at end of file diff --git a/tests/intetration_tests/kbqa/__init__.py b/tests/intetration_tests/kbqa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/intetration_tests/vector_store/__init__.py b/tests/intetration_tests/vector_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/embedding_engine/document_test.py b/tests/unit_tests/embedding_engine/document_test.py similarity index 100% rename from tests/unit/embedding_engine/document_test.py rename to tests/unit_tests/embedding_engine/document_test.py diff --git a/tests/unit/embedding_engine/url_test.py b/tests/unit_tests/embedding_engine/url_test.py similarity index 100% rename from tests/unit/embedding_engine/url_test.py rename to tests/unit_tests/embedding_engine/url_test.py diff --git a/tests/unit/test_plugins.py b/tests/unit_tests/test_plugins.py similarity index 100% rename from tests/unit/test_plugins.py rename to tests/unit_tests/test_plugins.py diff --git a/tests/unit_tests/vector_store/test_pgvector.py b/tests/unit_tests/vector_store/test_pgvector.py new file mode 100644 index 000000000..c96643683 --- /dev/null +++ b/tests/unit_tests/vector_store/test_pgvector.py @@ -0,0 +1,10 @@ +import pytest + +from pilot import vector_store +from pilot.vector_store.base import VectorStoreBase + +def test_vetorestore_imports() -> None: + """ Simple test to make sure all things can be imported.""" + + for cls in vector_store.__all__: + assert issubclass(getattr(vector_store, cls), VectorStoreBase)