mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-31 16:39:48 +00:00
feat: add pgvector vectorstore
This commit is contained in:
@@ -0,0 +1,36 @@
|
||||
from typing import Any
|
||||
|
||||
def _import_pgvector() -> Any:
|
||||
from pilot.vector_store.pgvector_store import PGVectorStore
|
||||
return PGVectorStore
|
||||
|
||||
def _import_milvus() -> Any:
|
||||
from pilot.vector_store.milvus_store import MilvusStore
|
||||
return MilvusStore
|
||||
|
||||
def _import_chroma() -> Any:
|
||||
from pilot.vector_store.chroma_store import ChromaStore
|
||||
return ChromaStore
|
||||
|
||||
def _import_weaviate() -> Any:
|
||||
from pilot.vector_store.weaviate_store import WeaviateStore
|
||||
return WeaviateStore
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "Chroma":
|
||||
return _import_chroma()
|
||||
elif name == "Milvus":
|
||||
return _import_milvus()
|
||||
elif name == "Weaviate":
|
||||
return _import_weaviate()
|
||||
elif name == "PGVector":
|
||||
return _import_pgvector()
|
||||
else:
|
||||
raise AttributeError(f"Could not find: {name}")
|
||||
|
||||
__all__ = [
|
||||
"Chroma",
|
||||
"Milvus",
|
||||
"Weaviate",
|
||||
"PGVector"
|
||||
]
|
@@ -15,9 +15,9 @@ class VectorStoreBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def vector_name_exists(self, text, topk) -> None:
|
||||
def vector_name_exists(self) -> bool:
|
||||
"""is vector store name exist."""
|
||||
pass
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_ids(self, ids):
|
||||
|
@@ -1,17 +1,9 @@
|
||||
from pilot.vector_store.chroma_store import ChromaStore
|
||||
|
||||
# from pilot.vector_store.weaviate_store import WeaviateStore
|
||||
from pilot import vector_store
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
connector = {"Chroma": ChromaStore}
|
||||
|
||||
try:
|
||||
from pilot.vector_store.milvus_store import MilvusStore
|
||||
|
||||
connector["Milvus"] = MilvusStore
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
class VectorStoreConnector:
|
||||
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
||||
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
||||
@@ -24,9 +16,16 @@ class VectorStoreConnector:
|
||||
def __init__(self, vector_store_type, ctx: {}) -> None:
|
||||
"""initialize vector store connector."""
|
||||
self.ctx = ctx
|
||||
self.connector_class = connector[vector_store_type]
|
||||
self._register()
|
||||
|
||||
if self._match(vector_store_type):
|
||||
self.connector_class = connector.get(vector_store_type)
|
||||
else:
|
||||
raise Exception(f"Vector Type Not support. {0}", vector_store_type)
|
||||
|
||||
self.client = self.connector_class(ctx)
|
||||
|
||||
|
||||
def load_document(self, docs):
|
||||
"""load document in vector database."""
|
||||
return self.client.load_document(docs)
|
||||
@@ -46,3 +45,14 @@ class VectorStoreConnector:
|
||||
def delete_by_ids(self, ids):
|
||||
"""vector store delete by ids."""
|
||||
return self.client.delete_by_ids(ids=ids)
|
||||
|
||||
def _match(self, vector_store_type) -> bool:
|
||||
if connector.get(vector_store_type):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _register(self):
|
||||
for cls in vector_store.__all__:
|
||||
if issubclass(getattr(vector_store, cls), VectorStoreBase):
|
||||
connector.update({cls, getattr(vector_store, cls)})
|
50
pilot/vector_store/pgvector_store.py
Normal file
50
pilot/vector_store/pgvector_store.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Any
|
||||
import logging
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PGVectorStore(VectorStoreBase):
|
||||
"""`Postgres.PGVector` vector store.
|
||||
|
||||
To use this, you should have the ``pgvector`` python package installed.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx: dict) -> None:
|
||||
"""init pgvector storage"""
|
||||
|
||||
from langchain.vectorstores import PGVector
|
||||
|
||||
self.ctx = ctx
|
||||
self.connection_string = ctx.get("connection_string", None)
|
||||
self.embeddings = ctx.get("embeddings", None)
|
||||
self.collection_name = ctx.get("vector_store_name", None)
|
||||
|
||||
self.vector_store_client = PGVector(
|
||||
embedding_function=self.embeddings,
|
||||
collection_name=self.collection_name,
|
||||
connection_string=self.connection_string
|
||||
)
|
||||
|
||||
def similar_search(self, text, topk, **kwargs: Any) -> None:
|
||||
return self.vector_store_client.similarity_search(text, topk)
|
||||
|
||||
|
||||
def vector_name_exists(self):
|
||||
try:
|
||||
self.vector_store_client.create_collection()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("vector_name_exists error", e.message)
|
||||
return False
|
||||
|
||||
def load_document(self, documents) -> None:
|
||||
return self.vector_store_client.from_documents(documents)
|
||||
|
||||
|
||||
def delete_vector_name(self, vector_name):
|
||||
return self.vector_store_client.delete_collection()
|
||||
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
return self.vector_store_client.delete(ids)
|
0
tests/intetration_tests/kbqa/__init__.py
Normal file
0
tests/intetration_tests/kbqa/__init__.py
Normal file
0
tests/intetration_tests/vector_store/__init__.py
Normal file
0
tests/intetration_tests/vector_store/__init__.py
Normal file
10
tests/unit_tests/vector_store/test_pgvector.py
Normal file
10
tests/unit_tests/vector_store/test_pgvector.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import pytest
|
||||
|
||||
from pilot import vector_store
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
def test_vetorestore_imports() -> None:
|
||||
""" Simple test to make sure all things can be imported."""
|
||||
|
||||
for cls in vector_store.__all__:
|
||||
assert issubclass(getattr(vector_store, cls), VectorStoreBase)
|
Reference in New Issue
Block a user