fix: vectordb lazy load

This commit is contained in:
aries_ckt 2023-10-17 11:52:45 +08:00
parent 88f1daaa50
commit 496696537d
6 changed files with 47 additions and 36 deletions

View File

@ -1,21 +1,30 @@
from typing import Any from typing import Any
def _import_pgvector() -> Any: def _import_pgvector() -> Any:
from pilot.vector_store.pgvector_store import PGVectorStore from pilot.vector_store.pgvector_store import PGVectorStore
return PGVectorStore return PGVectorStore
def _import_milvus() -> Any: def _import_milvus() -> Any:
from pilot.vector_store.milvus_store import MilvusStore from pilot.vector_store.milvus_store import MilvusStore
return MilvusStore return MilvusStore
def _import_chroma() -> Any: def _import_chroma() -> Any:
from pilot.vector_store.chroma_store import ChromaStore from pilot.vector_store.chroma_store import ChromaStore
return ChromaStore return ChromaStore
def _import_weaviate() -> Any: def _import_weaviate() -> Any:
from pilot.vector_store.weaviate_store import WeaviateStore from pilot.vector_store.weaviate_store import WeaviateStore
return WeaviateStore return WeaviateStore
def __getattr__(name: str) -> Any: def __getattr__(name: str) -> Any:
if name == "Chroma": if name == "Chroma":
return _import_chroma() return _import_chroma()
@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any:
else: else:
raise AttributeError(f"Could not find: {name}") raise AttributeError(f"Could not find: {name}")
__all__ = [
"Chroma", __all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
"Milvus",
"Weaviate",
"PGVector"
]

View File

@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase
connector = {} connector = {}
class VectorStoreConnector: class VectorStoreConnector:
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1. """VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate) 1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
@ -25,7 +26,6 @@ class VectorStoreConnector:
print(self.connector_class) print(self.connector_class)
self.client = self.connector_class(ctx) self.client = self.connector_class(ctx)
def load_document(self, docs): def load_document(self, docs):
"""load document in vector database.""" """load document in vector database."""
return self.client.load_document(docs) return self.client.load_document(docs)

View File

@ -3,7 +3,6 @@ import logging
import os import os
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from pymilvus import Collection, DataType, connections, utility
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase):
"""Milvus database""" """Milvus database"""
def __init__(self, ctx: {}) -> None: def __init__(self, ctx: {}) -> None:
from pymilvus import Collection, DataType, connections, utility
"""init a milvus storage connection. """init a milvus storage connection.
Args: Args:
@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase):
DataType, DataType,
FieldSchema, FieldSchema,
connections, connections,
utility,
) )
from pymilvus.orm.types import infer_dtype_bydata from pymilvus.orm.types import infer_dtype_bydata
except ImportError: except ImportError:
@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase):
return doc_ids return doc_ids
def similar_search(self, text, topk) -> None: def similar_search(self, text, topk) -> None:
from pymilvus import Collection, DataType
"""similar_search in vector database.""" """similar_search in vector database."""
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
schema = self.col.schema schema = self.col.schema
@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase):
return data[0], ret return data[0], ret
def vector_name_exists(self): def vector_name_exists(self):
from pymilvus import utility
"""is vector store name exist.""" """is vector store name exist."""
return utility.has_collection(self.collection_name) return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name): def delete_vector_name(self, vector_name):
from pymilvus import utility
"""milvus delete collection name""" """milvus delete collection name"""
logger.info(f"milvus vector_name:{vector_name} begin delete...") logger.info(f"milvus vector_name:{vector_name} begin delete...")
utility.drop_collection(vector_name) utility.drop_collection(vector_name)
return True return True
def delete_by_ids(self, ids): def delete_by_ids(self, ids):
from pymilvus import Collection
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
"""milvus delete vectors by ids""" """milvus delete vectors by ids"""
logger.info(f"begin delete milvus ids...") logger.info(f"begin delete milvus ids...")
@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase):
delet_expr = f"{self.primary_field} in {doc_ids}" delet_expr = f"{self.primary_field} in {doc_ids}"
self.col.delete(delet_expr) self.col.delete(delet_expr)
return True return True
def close(self):
connections.disconnect()

View File

@ -7,6 +7,7 @@ logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
class PGVectorStore(VectorStoreBase): class PGVectorStore(VectorStoreBase):
"""`Postgres.PGVector` vector store. """`Postgres.PGVector` vector store.
@ -26,13 +27,12 @@ class PGVectorStore(VectorStoreBase):
self.vector_store_client = PGVector( self.vector_store_client = PGVector(
embedding_function=self.embeddings, embedding_function=self.embeddings,
collection_name=self.collection_name, collection_name=self.collection_name,
connection_string=self.connection_string connection_string=self.connection_string,
) )
def similar_search(self, text, topk, **kwargs: Any) -> None: def similar_search(self, text, topk, **kwargs: Any) -> None:
return self.vector_store_client.similarity_search(text, topk) return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self): def vector_name_exists(self):
try: try:
self.vector_store_client.create_collection() self.vector_store_client.create_collection()
@ -44,10 +44,8 @@ class PGVectorStore(VectorStoreBase):
def load_document(self, documents) -> None: def load_document(self, documents) -> None:
return self.vector_store_client.from_documents(documents) return self.vector_store_client.from_documents(documents)
def delete_vector_name(self, vector_name): def delete_vector_name(self, vector_name):
return self.vector_store_client.delete_collection() return self.vector_store_client.delete_collection()
def delete_by_ids(self, ids): def delete_by_ids(self, ids):
return self.vector_store_client.delete(ids) return self.vector_store_client.delete(ids)

View File

@ -3,6 +3,7 @@ import pytest
from pilot import vector_store from pilot import vector_store
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
def test_vetorestore_imports() -> None: def test_vetorestore_imports() -> None:
"""Simple test to make sure all things can be imported.""" """Simple test to make sure all things can be imported."""