mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-06 10:54:29 +00:00
fix:vectordb lazy load
This commit is contained in:
parent
88f1daaa50
commit
496696537d
@ -1,21 +1,30 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def _import_pgvector() -> Any:
|
def _import_pgvector() -> Any:
|
||||||
from pilot.vector_store.pgvector_store import PGVectorStore
|
from pilot.vector_store.pgvector_store import PGVectorStore
|
||||||
|
|
||||||
return PGVectorStore
|
return PGVectorStore
|
||||||
|
|
||||||
|
|
||||||
def _import_milvus() -> Any:
|
def _import_milvus() -> Any:
|
||||||
from pilot.vector_store.milvus_store import MilvusStore
|
from pilot.vector_store.milvus_store import MilvusStore
|
||||||
|
|
||||||
return MilvusStore
|
return MilvusStore
|
||||||
|
|
||||||
|
|
||||||
def _import_chroma() -> Any:
|
def _import_chroma() -> Any:
|
||||||
from pilot.vector_store.chroma_store import ChromaStore
|
from pilot.vector_store.chroma_store import ChromaStore
|
||||||
|
|
||||||
return ChromaStore
|
return ChromaStore
|
||||||
|
|
||||||
|
|
||||||
def _import_weaviate() -> Any:
|
def _import_weaviate() -> Any:
|
||||||
from pilot.vector_store.weaviate_store import WeaviateStore
|
from pilot.vector_store.weaviate_store import WeaviateStore
|
||||||
|
|
||||||
return WeaviateStore
|
return WeaviateStore
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> Any:
|
def __getattr__(name: str) -> Any:
|
||||||
if name == "Chroma":
|
if name == "Chroma":
|
||||||
return _import_chroma()
|
return _import_chroma()
|
||||||
@ -28,9 +37,5 @@ def __getattr__(name: str) -> Any:
|
|||||||
else:
|
else:
|
||||||
raise AttributeError(f"Could not find: {name}")
|
raise AttributeError(f"Could not find: {name}")
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Chroma",
|
__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
|
||||||
"Milvus",
|
|
||||||
"Weaviate",
|
|
||||||
"PGVector"
|
|
||||||
]
|
|
||||||
|
@ -17,7 +17,7 @@ class VectorStoreBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def vector_name_exists(self) -> bool:
|
def vector_name_exists(self) -> bool:
|
||||||
"""is vector store name exist."""
|
"""is vector store name exist."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_ids(self, ids):
|
def delete_by_ids(self, ids):
|
||||||
|
@ -3,6 +3,7 @@ from pilot.vector_store.base import VectorStoreBase
|
|||||||
|
|
||||||
connector = {}
|
connector = {}
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreConnector:
|
class VectorStoreConnector:
|
||||||
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
||||||
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
||||||
@ -16,16 +17,15 @@ class VectorStoreConnector:
|
|||||||
"""initialize vector store connector."""
|
"""initialize vector store connector."""
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self._register()
|
self._register()
|
||||||
|
|
||||||
if self._match(vector_store_type):
|
if self._match(vector_store_type):
|
||||||
self.connector_class = connector.get(vector_store_type)
|
self.connector_class = connector.get(vector_store_type)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Vector Type Not support. {0}", vector_store_type)
|
raise Exception(f"Vector Type Not support. {0}", vector_store_type)
|
||||||
|
|
||||||
print(self.connector_class)
|
print(self.connector_class)
|
||||||
self.client = self.connector_class(ctx)
|
self.client = self.connector_class(ctx)
|
||||||
|
|
||||||
|
|
||||||
def load_document(self, docs):
|
def load_document(self, docs):
|
||||||
"""load document in vector database."""
|
"""load document in vector database."""
|
||||||
return self.client.load_document(docs)
|
return self.client.load_document(docs)
|
||||||
@ -51,9 +51,9 @@ class VectorStoreConnector:
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _register(self):
|
def _register(self):
|
||||||
for cls in vector_store.__all__:
|
for cls in vector_store.__all__:
|
||||||
if issubclass(getattr(vector_store, cls), VectorStoreBase):
|
if issubclass(getattr(vector_store, cls), VectorStoreBase):
|
||||||
_k, _v = cls, getattr(vector_store, cls)
|
_k, _v = cls, getattr(vector_store, cls)
|
||||||
connector.update({_k: _v})
|
connector.update({_k: _v})
|
||||||
|
@ -3,7 +3,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Iterable, List, Optional, Tuple
|
from typing import Any, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from pymilvus import Collection, DataType, connections, utility
|
|
||||||
|
|
||||||
from pilot.vector_store.base import VectorStoreBase
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
|
|
||||||
@ -14,6 +13,8 @@ class MilvusStore(VectorStoreBase):
|
|||||||
"""Milvus database"""
|
"""Milvus database"""
|
||||||
|
|
||||||
def __init__(self, ctx: {}) -> None:
|
def __init__(self, ctx: {}) -> None:
|
||||||
|
from pymilvus import Collection, DataType, connections, utility
|
||||||
|
|
||||||
"""init a milvus storage connection.
|
"""init a milvus storage connection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -85,6 +86,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
DataType,
|
DataType,
|
||||||
FieldSchema,
|
FieldSchema,
|
||||||
connections,
|
connections,
|
||||||
|
utility,
|
||||||
)
|
)
|
||||||
from pymilvus.orm.types import infer_dtype_bydata
|
from pymilvus.orm.types import infer_dtype_bydata
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -260,6 +262,8 @@ class MilvusStore(VectorStoreBase):
|
|||||||
return doc_ids
|
return doc_ids
|
||||||
|
|
||||||
def similar_search(self, text, topk) -> None:
|
def similar_search(self, text, topk) -> None:
|
||||||
|
from pymilvus import Collection, DataType
|
||||||
|
|
||||||
"""similar_search in vector database."""
|
"""similar_search in vector database."""
|
||||||
self.col = Collection(self.collection_name)
|
self.col = Collection(self.collection_name)
|
||||||
schema = self.col.schema
|
schema = self.col.schema
|
||||||
@ -324,16 +328,22 @@ class MilvusStore(VectorStoreBase):
|
|||||||
return data[0], ret
|
return data[0], ret
|
||||||
|
|
||||||
def vector_name_exists(self):
|
def vector_name_exists(self):
|
||||||
|
from pymilvus import utility
|
||||||
|
|
||||||
"""is vector store name exist."""
|
"""is vector store name exist."""
|
||||||
return utility.has_collection(self.collection_name)
|
return utility.has_collection(self.collection_name)
|
||||||
|
|
||||||
def delete_vector_name(self, vector_name):
|
def delete_vector_name(self, vector_name):
|
||||||
|
from pymilvus import utility
|
||||||
|
|
||||||
"""milvus delete collection name"""
|
"""milvus delete collection name"""
|
||||||
logger.info(f"milvus vector_name:{vector_name} begin delete...")
|
logger.info(f"milvus vector_name:{vector_name} begin delete...")
|
||||||
utility.drop_collection(vector_name)
|
utility.drop_collection(vector_name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_by_ids(self, ids):
|
def delete_by_ids(self, ids):
|
||||||
|
from pymilvus import Collection
|
||||||
|
|
||||||
self.col = Collection(self.collection_name)
|
self.col = Collection(self.collection_name)
|
||||||
"""milvus delete vectors by ids"""
|
"""milvus delete vectors by ids"""
|
||||||
logger.info(f"begin delete milvus ids...")
|
logger.info(f"begin delete milvus ids...")
|
||||||
@ -342,6 +352,3 @@ class MilvusStore(VectorStoreBase):
|
|||||||
delet_expr = f"{self.primary_field} in {doc_ids}"
|
delet_expr = f"{self.primary_field} in {doc_ids}"
|
||||||
self.col.delete(delet_expr)
|
self.col.delete(delet_expr)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def close(self):
|
|
||||||
connections.disconnect()
|
|
||||||
|
@ -7,32 +7,32 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class PGVectorStore(VectorStoreBase):
|
class PGVectorStore(VectorStoreBase):
|
||||||
"""`Postgres.PGVector` vector store.
|
"""`Postgres.PGVector` vector store.
|
||||||
|
|
||||||
To use this, you should have the ``pgvector`` python package installed.
|
To use this, you should have the ``pgvector`` python package installed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ctx: dict) -> None:
|
def __init__(self, ctx: dict) -> None:
|
||||||
"""init pgvector storage"""
|
"""init pgvector storage"""
|
||||||
|
|
||||||
from langchain.vectorstores import PGVector
|
from langchain.vectorstores import PGVector
|
||||||
|
|
||||||
self.ctx = ctx
|
self.ctx = ctx
|
||||||
self.connection_string = ctx.get("connection_string", None)
|
self.connection_string = ctx.get("connection_string", None)
|
||||||
self.embeddings = ctx.get("embeddings", None)
|
self.embeddings = ctx.get("embeddings", None)
|
||||||
self.collection_name = ctx.get("vector_store_name", None)
|
self.collection_name = ctx.get("vector_store_name", None)
|
||||||
|
|
||||||
self.vector_store_client = PGVector(
|
self.vector_store_client = PGVector(
|
||||||
embedding_function=self.embeddings,
|
embedding_function=self.embeddings,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
connection_string=self.connection_string
|
connection_string=self.connection_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
def similar_search(self, text, topk, **kwargs: Any) -> None:
|
|
||||||
return self.vector_store_client.similarity_search(text, topk)
|
|
||||||
|
|
||||||
|
def similar_search(self, text, topk, **kwargs: Any) -> None:
|
||||||
|
return self.vector_store_client.similarity_search(text, topk)
|
||||||
|
|
||||||
def vector_name_exists(self):
|
def vector_name_exists(self):
|
||||||
try:
|
try:
|
||||||
self.vector_store_client.create_collection()
|
self.vector_store_client.create_collection()
|
||||||
@ -40,14 +40,12 @@ class PGVectorStore(VectorStoreBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("vector_name_exists error", e.message)
|
logger.error("vector_name_exists error", e.message)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_document(self, documents) -> None:
|
def load_document(self, documents) -> None:
|
||||||
return self.vector_store_client.from_documents(documents)
|
return self.vector_store_client.from_documents(documents)
|
||||||
|
|
||||||
|
|
||||||
def delete_vector_name(self, vector_name):
|
def delete_vector_name(self, vector_name):
|
||||||
return self.vector_store_client.delete_collection()
|
return self.vector_store_client.delete_collection()
|
||||||
|
|
||||||
|
|
||||||
def delete_by_ids(self, ids):
|
def delete_by_ids(self, ids):
|
||||||
return self.vector_store_client.delete(ids)
|
return self.vector_store_client.delete(ids)
|
||||||
|
@ -3,8 +3,9 @@ import pytest
|
|||||||
from pilot import vector_store
|
from pilot import vector_store
|
||||||
from pilot.vector_store.base import VectorStoreBase
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
|
|
||||||
def test_vetorestore_imports() -> None:
|
|
||||||
""" Simple test to make sure all things can be imported."""
|
|
||||||
|
|
||||||
for cls in vector_store.__all__:
|
def test_vetorestore_imports() -> None:
|
||||||
|
"""Simple test to make sure all things can be imported."""
|
||||||
|
|
||||||
|
for cls in vector_store.__all__:
|
||||||
assert issubclass(getattr(vector_store, cls), VectorStoreBase)
|
assert issubclass(getattr(vector_store, cls), VectorStoreBase)
|
||||||
|
Loading…
Reference in New Issue
Block a user