fix:vectordb lazy load

This commit is contained in:
aries_ckt 2023-10-17 11:52:45 +08:00
parent 88f1daaa50
commit 496696537d
6 changed files with 47 additions and 36 deletions

View File

@ -1,21 +1,30 @@
from typing import Any
def _import_pgvector() -> Any:
from pilot.vector_store.pgvector_store import PGVectorStore
from pilot.vector_store.pgvector_store import PGVectorStore
return PGVectorStore
def _import_milvus() -> Any:
from pilot.vector_store.milvus_store import MilvusStore
return MilvusStore
def _import_chroma() -> Any:
from pilot.vector_store.chroma_store import ChromaStore
return ChromaStore
def _import_weaviate() -> Any:
from pilot.vector_store.weaviate_store import WeaviateStore
return WeaviateStore
def __getattr__(name: str) -> Any:
if name == "Chroma":
return _import_chroma()
@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any:
else:
raise AttributeError(f"Could not find: {name}")
__all__ = [
"Chroma",
"Milvus",
"Weaviate",
"PGVector"
]
__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]

View File

@ -17,7 +17,7 @@ class VectorStoreBase(ABC):
@abstractmethod
def vector_name_exists(self) -> bool:
"""is vector store name exist."""
return False
return False
@abstractmethod
def delete_by_ids(self, ids):

View File

@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase
connector = {}
class VectorStoreConnector:
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
@ -16,16 +17,15 @@ class VectorStoreConnector:
"""initialize vector store connector."""
self.ctx = ctx
self._register()
if self._match(vector_store_type):
self.connector_class = connector.get(vector_store_type)
else:
raise Exception(f"Vector Type Not support. {0}", vector_store_type)
print(self.connector_class)
print(self.connector_class)
self.client = self.connector_class(ctx)
def load_document(self, docs):
"""load document in vector database."""
return self.client.load_document(docs)
@ -51,9 +51,9 @@ class VectorStoreConnector:
return True
else:
return False
def _register(self):
for cls in vector_store.__all__:
if issubclass(getattr(vector_store, cls), VectorStoreBase):
_k, _v = cls, getattr(vector_store, cls)
connector.update({_k: _v})
connector.update({_k: _v})

View File

@ -3,7 +3,6 @@ import logging
import os
from typing import Any, Iterable, List, Optional, Tuple
from pymilvus import Collection, DataType, connections, utility
from pilot.vector_store.base import VectorStoreBase
@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase):
"""Milvus database"""
def __init__(self, ctx: {}) -> None:
from pymilvus import Collection, DataType, connections, utility
"""init a milvus storage connection.
Args:
@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase):
DataType,
FieldSchema,
connections,
utility,
)
from pymilvus.orm.types import infer_dtype_bydata
except ImportError:
@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase):
return doc_ids
def similar_search(self, text, topk) -> None:
from pymilvus import Collection, DataType
"""similar_search in vector database."""
self.col = Collection(self.collection_name)
schema = self.col.schema
@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase):
return data[0], ret
def vector_name_exists(self):
from pymilvus import utility
"""is vector store name exist."""
return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name):
from pymilvus import utility
"""milvus delete collection name"""
logger.info(f"milvus vector_name:{vector_name} begin delete...")
utility.drop_collection(vector_name)
return True
def delete_by_ids(self, ids):
from pymilvus import Collection
self.col = Collection(self.collection_name)
"""milvus delete vectors by ids"""
logger.info(f"begin delete milvus ids...")
@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase):
delet_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delet_expr)
return True
def close(self):
connections.disconnect()

View File

@ -7,32 +7,32 @@ logger = logging.getLogger(__name__)
CFG = Config()
class PGVectorStore(VectorStoreBase):
"""`Postgres.PGVector` vector store.
"""`Postgres.PGVector` vector store.
To use this, you should have the ``pgvector`` python package installed.
"""
def __init__(self, ctx: dict) -> None:
"""init pgvector storage"""
from langchain.vectorstores import PGVector
self.ctx = ctx
self.connection_string = ctx.get("connection_string", None)
self.embeddings = ctx.get("embeddings", None)
self.collection_name = ctx.get("vector_store_name", None)
self.vector_store_client = PGVector(
embedding_function=self.embeddings,
collection_name=self.collection_name,
connection_string=self.connection_string
connection_string=self.connection_string,
)
def similar_search(self, text, topk, **kwargs: Any) -> None:
return self.vector_store_client.similarity_search(text, topk)
def similar_search(self, text, topk, **kwargs: Any) -> None:
return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self):
try:
self.vector_store_client.create_collection()
@ -40,14 +40,12 @@ class PGVectorStore(VectorStoreBase):
except Exception as e:
logger.error("vector_name_exists error", e.message)
return False
def load_document(self, documents) -> None:
return self.vector_store_client.from_documents(documents)
def delete_vector_name(self, vector_name):
return self.vector_store_client.delete_collection()
return self.vector_store_client.delete_collection()
def delete_by_ids(self, ids):
return self.vector_store_client.delete(ids)
return self.vector_store_client.delete(ids)

View File

@ -3,8 +3,9 @@ import pytest
from pilot import vector_store
from pilot.vector_store.base import VectorStoreBase
def test_vetorestore_imports() -> None:
""" Simple test to make sure all things can be imported."""
for cls in vector_store.__all__:
def test_vetorestore_imports() -> None:
"""Simple test to make sure all things can be imported."""
for cls in vector_store.__all__:
assert issubclass(getattr(vector_store, cls), VectorStoreBase)