fix:vectordb lazy load

This commit is contained in:
aries_ckt 2023-10-17 11:52:45 +08:00
parent 88f1daaa50
commit 496696537d
6 changed files with 47 additions and 36 deletions

View File

@ -1,21 +1,30 @@
from typing import Any from typing import Any
def _import_pgvector() -> Any: def _import_pgvector() -> Any:
from pilot.vector_store.pgvector_store import PGVectorStore from pilot.vector_store.pgvector_store import PGVectorStore
return PGVectorStore return PGVectorStore
def _import_milvus() -> Any: def _import_milvus() -> Any:
from pilot.vector_store.milvus_store import MilvusStore from pilot.vector_store.milvus_store import MilvusStore
return MilvusStore return MilvusStore
def _import_chroma() -> Any: def _import_chroma() -> Any:
from pilot.vector_store.chroma_store import ChromaStore from pilot.vector_store.chroma_store import ChromaStore
return ChromaStore return ChromaStore
def _import_weaviate() -> Any: def _import_weaviate() -> Any:
from pilot.vector_store.weaviate_store import WeaviateStore from pilot.vector_store.weaviate_store import WeaviateStore
return WeaviateStore return WeaviateStore
def __getattr__(name: str) -> Any: def __getattr__(name: str) -> Any:
if name == "Chroma": if name == "Chroma":
return _import_chroma() return _import_chroma()
@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any:
else: else:
raise AttributeError(f"Could not find: {name}") raise AttributeError(f"Could not find: {name}")
__all__ = [
"Chroma", __all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
"Milvus",
"Weaviate",
"PGVector"
]

View File

@ -17,7 +17,7 @@ class VectorStoreBase(ABC):
@abstractmethod @abstractmethod
def vector_name_exists(self) -> bool: def vector_name_exists(self) -> bool:
"""is vector store name exist.""" """is vector store name exist."""
return False return False
@abstractmethod @abstractmethod
def delete_by_ids(self, ids): def delete_by_ids(self, ids):

View File

@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase
connector = {} connector = {}
class VectorStoreConnector: class VectorStoreConnector:
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1. """VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate) 1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
@ -16,16 +17,15 @@ class VectorStoreConnector:
"""initialize vector store connector.""" """initialize vector store connector."""
self.ctx = ctx self.ctx = ctx
self._register() self._register()
if self._match(vector_store_type): if self._match(vector_store_type):
self.connector_class = connector.get(vector_store_type) self.connector_class = connector.get(vector_store_type)
else: else:
raise Exception(f"Vector Type Not support. {0}", vector_store_type) raise Exception(f"Vector Type Not support. {0}", vector_store_type)
print(self.connector_class) print(self.connector_class)
self.client = self.connector_class(ctx) self.client = self.connector_class(ctx)
def load_document(self, docs): def load_document(self, docs):
"""load document in vector database.""" """load document in vector database."""
return self.client.load_document(docs) return self.client.load_document(docs)
@ -51,9 +51,9 @@ class VectorStoreConnector:
return True return True
else: else:
return False return False
def _register(self): def _register(self):
for cls in vector_store.__all__: for cls in vector_store.__all__:
if issubclass(getattr(vector_store, cls), VectorStoreBase): if issubclass(getattr(vector_store, cls), VectorStoreBase):
_k, _v = cls, getattr(vector_store, cls) _k, _v = cls, getattr(vector_store, cls)
connector.update({_k: _v}) connector.update({_k: _v})

View File

@ -3,7 +3,6 @@ import logging
import os import os
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from pymilvus import Collection, DataType, connections, utility
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase):
"""Milvus database""" """Milvus database"""
def __init__(self, ctx: {}) -> None: def __init__(self, ctx: {}) -> None:
from pymilvus import Collection, DataType, connections, utility
"""init a milvus storage connection. """init a milvus storage connection.
Args: Args:
@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase):
DataType, DataType,
FieldSchema, FieldSchema,
connections, connections,
utility,
) )
from pymilvus.orm.types import infer_dtype_bydata from pymilvus.orm.types import infer_dtype_bydata
except ImportError: except ImportError:
@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase):
return doc_ids return doc_ids
def similar_search(self, text, topk) -> None: def similar_search(self, text, topk) -> None:
from pymilvus import Collection, DataType
"""similar_search in vector database.""" """similar_search in vector database."""
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
schema = self.col.schema schema = self.col.schema
@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase):
return data[0], ret return data[0], ret
def vector_name_exists(self): def vector_name_exists(self):
from pymilvus import utility
"""is vector store name exist.""" """is vector store name exist."""
return utility.has_collection(self.collection_name) return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name): def delete_vector_name(self, vector_name):
from pymilvus import utility
"""milvus delete collection name""" """milvus delete collection name"""
logger.info(f"milvus vector_name:{vector_name} begin delete...") logger.info(f"milvus vector_name:{vector_name} begin delete...")
utility.drop_collection(vector_name) utility.drop_collection(vector_name)
return True return True
def delete_by_ids(self, ids): def delete_by_ids(self, ids):
from pymilvus import Collection
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
"""milvus delete vectors by ids""" """milvus delete vectors by ids"""
logger.info(f"begin delete milvus ids...") logger.info(f"begin delete milvus ids...")
@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase):
delet_expr = f"{self.primary_field} in {doc_ids}" delet_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delet_expr) self.col.delete(delet_expr)
return True return True
def close(self):
connections.disconnect()

View File

@ -7,32 +7,32 @@ logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
class PGVectorStore(VectorStoreBase): class PGVectorStore(VectorStoreBase):
"""`Postgres.PGVector` vector store. """`Postgres.PGVector` vector store.
To use this, you should have the ``pgvector`` python package installed. To use this, you should have the ``pgvector`` python package installed.
""" """
def __init__(self, ctx: dict) -> None: def __init__(self, ctx: dict) -> None:
"""init pgvector storage""" """init pgvector storage"""
from langchain.vectorstores import PGVector from langchain.vectorstores import PGVector
self.ctx = ctx self.ctx = ctx
self.connection_string = ctx.get("connection_string", None) self.connection_string = ctx.get("connection_string", None)
self.embeddings = ctx.get("embeddings", None) self.embeddings = ctx.get("embeddings", None)
self.collection_name = ctx.get("vector_store_name", None) self.collection_name = ctx.get("vector_store_name", None)
self.vector_store_client = PGVector( self.vector_store_client = PGVector(
embedding_function=self.embeddings, embedding_function=self.embeddings,
collection_name=self.collection_name, collection_name=self.collection_name,
connection_string=self.connection_string connection_string=self.connection_string,
) )
def similar_search(self, text, topk, **kwargs: Any) -> None:
return self.vector_store_client.similarity_search(text, topk)
def similar_search(self, text, topk, **kwargs: Any) -> None:
return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self): def vector_name_exists(self):
try: try:
self.vector_store_client.create_collection() self.vector_store_client.create_collection()
@ -40,14 +40,12 @@ class PGVectorStore(VectorStoreBase):
except Exception as e: except Exception as e:
logger.error("vector_name_exists error", e.message) logger.error("vector_name_exists error", e.message)
return False return False
def load_document(self, documents) -> None: def load_document(self, documents) -> None:
return self.vector_store_client.from_documents(documents) return self.vector_store_client.from_documents(documents)
def delete_vector_name(self, vector_name): def delete_vector_name(self, vector_name):
return self.vector_store_client.delete_collection() return self.vector_store_client.delete_collection()
def delete_by_ids(self, ids): def delete_by_ids(self, ids):
return self.vector_store_client.delete(ids) return self.vector_store_client.delete(ids)

View File

@ -3,8 +3,9 @@ import pytest
from pilot import vector_store from pilot import vector_store
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
def test_vetorestore_imports() -> None:
""" Simple test to make sure all things can be imported."""
for cls in vector_store.__all__: def test_vetorestore_imports() -> None:
"""Simple test to make sure all things can be imported."""
for cls in vector_store.__all__:
assert issubclass(getattr(vector_store, cls), VectorStoreBase) assert issubclass(getattr(vector_store, cls), VectorStoreBase)