mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-04 02:25:08 +00:00
feat: add pgvector vectorstore
This commit is contained in:
@@ -0,0 +1,36 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
def _import_pgvector() -> Any:
    """Import the pgvector backend module on demand and return ``PGVectorStore``."""
    import pilot.vector_store.pgvector_store as store_module

    return store_module.PGVectorStore
|
||||||
|
|
||||||
|
def _import_milvus() -> Any:
    """Import the Milvus backend module on demand and return ``MilvusStore``."""
    import pilot.vector_store.milvus_store as store_module

    return store_module.MilvusStore
|
||||||
|
|
||||||
|
def _import_chroma() -> Any:
    """Import the Chroma backend module on demand and return ``ChromaStore``."""
    import pilot.vector_store.chroma_store as store_module

    return store_module.ChromaStore
|
||||||
|
|
||||||
|
def _import_weaviate() -> Any:
    """Import the Weaviate backend module on demand and return ``WeaviateStore``."""
    import pilot.vector_store.weaviate_store as store_module

    return store_module.WeaviateStore
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Resolve a public store name to its class, importing the backend on first use.

    Args:
        name: attribute requested on this module; one of ``"Chroma"``,
            ``"Milvus"``, ``"Weaviate"`` or ``"PGVector"``.

    Returns:
        The store class produced by the matching ``_import_*`` helper.

    Raises:
        AttributeError: if ``name`` is not one of the names listed above.
    """
    loaders = {
        "Chroma": _import_chroma,
        "Milvus": _import_milvus,
        "Weaviate": _import_weaviate,
        "PGVector": _import_pgvector,
    }
    loader = loaders.get(name)
    if loader is None:
        raise AttributeError(f"Could not find: {name}")
    return loader()
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Chroma",
|
||||||
|
"Milvus",
|
||||||
|
"Weaviate",
|
||||||
|
"PGVector"
|
||||||
|
]
|
@@ -15,9 +15,9 @@ class VectorStoreBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def vector_name_exists(self, text, topk) -> None:
|
def vector_name_exists(self) -> bool:
|
||||||
"""is vector store name exist."""
|
"""is vector store name exist."""
|
||||||
pass
|
return False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_ids(self, ids):
|
def delete_by_ids(self, ids):
|
||||||
|
@@ -1,17 +1,9 @@
|
|||||||
from pilot.vector_store.chroma_store import ChromaStore
|
from pilot.vector_store.chroma_store import ChromaStore
|
||||||
|
from pilot import vector_store
|
||||||
# from pilot.vector_store.weaviate_store import WeaviateStore
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
|
|
||||||
connector = {"Chroma": ChromaStore}
|
connector = {"Chroma": ChromaStore}
|
||||||
|
|
||||||
try:
|
|
||||||
from pilot.vector_store.milvus_store import MilvusStore
|
|
||||||
|
|
||||||
connector["Milvus"] = MilvusStore
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreConnector:
|
class VectorStoreConnector:
|
||||||
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
"""VectorStoreConnector, can connect different vector db provided load document api_v1 and similar search api_v1.
|
||||||
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
1.load_document:knowledge document source into vector store.(Chroma, Milvus, Weaviate)
|
||||||
@@ -24,9 +16,16 @@ class VectorStoreConnector:
|
|||||||
def __init__(self, vector_store_type, ctx: dict) -> None:
    """Select the backend class for ``vector_store_type`` and instantiate it.

    Args:
        vector_store_type: key into the module-level ``connector`` registry
            (for example ``"Chroma"``).
        ctx: configuration mapping handed unchanged to the backend constructor.

    Raises:
        ValueError: if ``vector_store_type`` has no entry in the registry.
    """
    self.ctx = ctx
    # Refresh the registry before the lookup so every exported store is known.
    self._register()

    if not self._match(vector_store_type):
        # Interpolate the requested type itself; a "{0}" placeholder inside an
        # f-string would only render the literal number 0.
        raise ValueError(f"Vector Type Not support. {vector_store_type}")

    self.connector_class = connector.get(vector_store_type)
    self.client = self.connector_class(ctx)
|
||||||
|
|
||||||
|
|
||||||
def load_document(self, docs):
|
def load_document(self, docs):
|
||||||
"""load document in vector database."""
|
"""load document in vector database."""
|
||||||
return self.client.load_document(docs)
|
return self.client.load_document(docs)
|
||||||
@@ -46,3 +45,14 @@ class VectorStoreConnector:
|
|||||||
def delete_by_ids(self, ids):
|
def delete_by_ids(self, ids):
|
||||||
"""vector store delete by ids."""
|
"""vector store delete by ids."""
|
||||||
return self.client.delete_by_ids(ids=ids)
|
return self.client.delete_by_ids(ids=ids)
|
||||||
|
|
||||||
|
def _match(self, vector_store_type) -> bool:
    """Report whether ``vector_store_type`` maps to a truthy entry in ``connector``."""
    return bool(connector.get(vector_store_type))
|
||||||
|
|
||||||
|
def _register(self):
    """Add every store exported by ``vector_store.__all__`` to ``connector``.

    Each exported name is resolved on the ``vector_store`` package and, when the
    result is a ``VectorStoreBase`` subclass, stored in the module-level
    ``connector`` dict under that name. Names whose backend module cannot be
    imported are skipped.
    """
    import logging

    for name in vector_store.__all__:
        try:
            store_cls = getattr(vector_store, name)
        except ImportError as err:
            # Resolving a name imports that backend's module, which fails when
            # its driver package is not installed. One missing backend must not
            # prevent the others from being registered.
            logging.getLogger(__name__).debug(
                "Skipping vector store %s: %s", name, err
            )
            continue
        if issubclass(store_cls, VectorStoreBase):
            # Assign a real key/value pair: ``{name, store_cls}`` would be a set,
            # which ``dict.update`` cannot consume.
            connector[name] = store_cls
|
50
pilot/vector_store/pgvector_store.py
Normal file
50
pilot/vector_store/pgvector_store.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from typing import Any
|
||||||
|
import logging
|
||||||
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class PGVectorStore(VectorStoreBase):
    """`Postgres.PGVector` vector store.

    To use this, you should have the ``pgvector`` python package installed.
    """

    def __init__(self, ctx: dict) -> None:
        """Create the LangChain ``PGVector`` client from ``ctx``.

        Args:
            ctx: configuration mapping. The keys read are ``connection_string``,
                ``embeddings`` and ``vector_store_name`` (used as the collection
                name); each defaults to ``None`` when absent.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself import langchain.
        from langchain.vectorstores import PGVector

        self.ctx = ctx
        self.connection_string = ctx.get("connection_string", None)
        self.embeddings = ctx.get("embeddings", None)
        self.collection_name = ctx.get("vector_store_name", None)

        self.vector_store_client = PGVector(
            embedding_function=self.embeddings,
            collection_name=self.collection_name,
            connection_string=self.connection_string,
        )

    def similar_search(self, text, topk, **kwargs: Any) -> Any:
        """Return the client's ``similarity_search`` result for ``text``.

        Args:
            text: query text.
            topk: number of results requested from the client.
            **kwargs: accepted for interface compatibility; not forwarded.
        """
        return self.vector_store_client.similarity_search(text, topk)

    def vector_name_exists(self) -> bool:
        """Return ``True`` if the client's ``create_collection`` call succeeds.

        Any exception raised by the client is logged and reported as ``False``.
        NOTE(review): this calls ``create_collection`` rather than performing a
        read-only lookup - confirm that is the intended existence check.
        """
        try:
            self.vector_store_client.create_collection()
            return True
        except Exception:
            # ``logger.exception`` records the traceback; Python 3 exceptions
            # have no ``.message`` attribute to read.
            logger.exception("vector_name_exists error")
            return False

    def load_document(self, documents) -> Any:
        """Add ``documents`` to the collection through the existing client.

        Returns:
            Whatever the client's ``add_documents`` returns.
        """
        # ``from_documents`` is a classmethod constructor that also requires an
        # embedding argument; ``add_documents`` writes through the instance that
        # was already configured in ``__init__``.
        return self.vector_store_client.add_documents(documents)

    def delete_vector_name(self, vector_name):
        """Delete the client's collection.

        ``vector_name`` is accepted but not used: the collection removed is the
        one the client was constructed with.
        """
        return self.vector_store_client.delete_collection()

    def delete_by_ids(self, ids):
        """Delete the entries identified by ``ids`` via the client's ``delete``."""
        return self.vector_store_client.delete(ids)
|
0
tests/intetration_tests/kbqa/__init__.py
Normal file
0
tests/intetration_tests/kbqa/__init__.py
Normal file
0
tests/intetration_tests/vector_store/__init__.py
Normal file
0
tests/intetration_tests/vector_store/__init__.py
Normal file
10
tests/unit_tests/vector_store/test_pgvector.py
Normal file
10
tests/unit_tests/vector_store/test_pgvector.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from pilot import vector_store
|
||||||
|
from pilot.vector_store.base import VectorStoreBase
|
||||||
|
|
||||||
|
def test_vetorestore_imports() -> None:
    """Check that every name exported by ``vector_store`` resolves to a store class."""
    for exported_name in vector_store.__all__:
        store_cls = getattr(vector_store, exported_name)
        assert issubclass(store_cls, VectorStoreBase)
|
Reference in New Issue
Block a user