refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
new file: dbgpt/storage/vector_store/__init__.py

from typing import Any


def _import_pgvector() -> Any:
    from dbgpt.storage.vector_store.pgvector_store import PGVectorStore

    return PGVectorStore


def _import_milvus() -> Any:
    from dbgpt.storage.vector_store.milvus_store import MilvusStore

    return MilvusStore


def _import_chroma() -> Any:
    from dbgpt.storage.vector_store.chroma_store import ChromaStore

    return ChromaStore


def _import_weaviate() -> Any:
    from dbgpt.storage.vector_store.weaviate_store import WeaviateStore

    return WeaviateStore


def __getattr__(name: str) -> Any:
    if name == "Chroma":
        return _import_chroma()
    elif name == "Milvus":
        return _import_milvus()
    elif name == "Weaviate":
        return _import_weaviate()
    elif name == "PGVector":
        return _import_pgvector()
    else:
        raise AttributeError(f"Could not find: {name}")


__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
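The module uses PEP 562's module-level __getattr__ for lazy imports: a backend class is only imported when its name is first looked up, so users do not need every vector database client installed. A minimal sketch of how that resolution behaves (illustrative only; it assumes the chromadb client is installed so the lazy import can succeed):

# Attribute access on the package triggers the module-level __getattr__,
# which imports and returns the backing store class on first use.
from dbgpt.storage import vector_store

chroma_cls = vector_store.Chroma  # lazily imports ChromaStore
print(chroma_cls.__name__)        # -> "ChromaStore"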
new file: dbgpt/storage/vector_store/base.py

from abc import ABC, abstractmethod
import math


class VectorStoreBase(ABC):
    """Base class for vector store databases."""

    @abstractmethod
    def load_document(self, documents):
        """Load documents into the vector database."""

    @abstractmethod
    def similar_search(self, text, topk):
        """Perform a similarity search in the vector database."""

    @abstractmethod
    def vector_name_exists(self) -> bool:
        """Whether the vector store name exists."""
        return False

    @abstractmethod
    def delete_by_ids(self, ids):
        """Delete vectors by ids."""

    @abstractmethod
    def delete_vector_name(self, vector_name):
        """Delete the named vector store."""

    def _normalization_vectors(self, vectors):
        """Normalize vectors to unit scale so distances map to [0, 1] scores."""
        import numpy as np

        norm = np.linalg.norm(vectors)
        return vectors / norm

    def _default_relevance_score_fn(self, distance: float) -> float:
        """Return a similarity score on a scale [0, 1]."""
        return 1.0 - distance / math.sqrt(2)
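Why 1 - distance / sqrt(2) works: after _normalization_vectors, embeddings have unit norm, so the L2 distance between two vectors with non-negative similarity lies in [0, sqrt(2)], and dividing by sqrt(2) rescales that range onto [0, 1]. A standalone check of the mapping (mirrors the two helpers above):

import math

import numpy as np


def score(distance: float) -> float:
    # Same mapping as _default_relevance_score_fn.
    return 1.0 - distance / math.sqrt(2)


a = np.array([1.0, 0.0])  # unit vector
b = np.array([0.0, 1.0])  # orthogonal unit vector
print(score(np.linalg.norm(a - a)))  # identical vectors  -> 1.0
print(score(np.linalg.norm(a - b)))  # orthogonal vectors -> 0.0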
new file: dbgpt/storage/vector_store/chroma_store.py

import os
import logging
from typing import Any

from chromadb.config import Settings
from chromadb import PersistentClient
from dbgpt.storage.vector_store.base import VectorStoreBase
from dbgpt.configs.model_config import PILOT_PATH

logger = logging.getLogger(__name__)


class ChromaStore(VectorStoreBase):
    """Chroma vector store."""

    def __init__(self, ctx: dict) -> None:
        from langchain.vectorstores import Chroma

        self.ctx = ctx
        chroma_path = ctx.get(
            "CHROMA_PERSIST_PATH",
            os.path.join(PILOT_PATH, "data"),
        )
        self.persist_dir = os.path.join(
            chroma_path, ctx["vector_store_name"] + ".vectordb"
        )
        self.embeddings = ctx.get("embeddings", None)
        chroma_settings = Settings(
            # chroma_db_impl="duckdb+parquet" is a deprecated Chroma configuration.
            persist_directory=self.persist_dir,
            anonymized_telemetry=False,
        )
        client = PersistentClient(path=self.persist_dir, settings=chroma_settings)

        collection_metadata = {"hnsw:space": "cosine"}
        self.vector_store_client = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embeddings,
            client=client,
            collection_metadata=collection_metadata,
        )

    def similar_search(self, text, topk, **kwargs: Any):
        logger.info("ChromaStore similar search")
        return self.vector_store_client.similarity_search(text, topk, **kwargs)

    def similar_search_with_scores(self, text, topk, score_threshold):
        """Chroma similarity search with relevance scores.

        Return docs and relevance scores in the range [0, 1].

        Args:
            text(str): query text.
            topk(int): number of docs to return. Defaults to 4.
            score_threshold(float): optional, a floating point value between 0 and 1
                used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.
        """
        logger.info("ChromaStore similar search with scores")
        docs_and_scores = (
            self.vector_store_client.similarity_search_with_relevance_scores(
                query=text, k=topk, score_threshold=score_threshold
            )
        )
        return docs_and_scores

    def vector_name_exists(self):
        logger.info(f"Check persist_dir: {self.persist_dir}")
        if not os.path.exists(self.persist_dir):
            return False
        files = os.listdir(self.persist_dir)
        # Skip the default file: chroma.sqlite3.
        files = list(filter(lambda f: f != "chroma.sqlite3", files))
        return len(files) > 0

    def load_document(self, documents):
        logger.info("ChromaStore load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
        return ids

    def delete_vector_name(self, vector_name):
        logger.info(f"chroma vector_name:{vector_name} begin delete...")
        self.vector_store_client.delete_collection()
        self._clean_persist_folder()
        return True

    def delete_by_ids(self, ids):
        logger.info("begin delete chroma ids...")
        collection = self.vector_store_client._collection
        collection.delete(ids=ids)

    def _clean_persist_folder(self):
        # Remove all files and subdirectories, then the persist dir itself.
        for root, dirs, files in os.walk(self.persist_dir, topdown=False):
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                os.rmdir(os.path.join(root, name))
        os.rmdir(self.persist_dir)
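A hedged usage sketch for ChromaStore; the ctx keys mirror __init__ above, while the store name, embedding model, and document content are illustrative and not part of this commit:

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

from dbgpt.storage.vector_store.chroma_store import ChromaStore

ctx = {
    "vector_store_name": "my_knowledge_space",  # hypothetical store name
    "embeddings": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    # "CHROMA_PERSIST_PATH": "/tmp/chroma",  # optional; defaults to PILOT_PATH/data
}
store = ChromaStore(ctx)
ids = store.load_document(
    [Document(page_content="DB-GPT supports Chroma.", metadata={"source": "docs"})]
)
docs = store.similar_search("which stores are supported?", 4)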
new file: dbgpt/storage/vector_store/connector.py

from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase

connector = {}


class VectorStoreConnector:
    """VectorStoreConnector connects to different vector databases behind a
    uniform document-loading and similarity-search API.

    1. load_document: load knowledge documents into the vector store (Chroma, Milvus, Weaviate).
    2. similar_search: similarity search against the vector store.
    """

    def __init__(self, vector_store_type, ctx: dict) -> None:
        """Initialize the vector store connector.

        Args:
            - vector_store_type: vector store type (Milvus, Chroma, Weaviate)
            - ctx: vector store config params.
        """
        self.ctx = ctx
        self._register()

        if self._match(vector_store_type):
            self.connector_class = connector.get(vector_store_type)
        else:
            raise Exception(f"Vector store type not supported: {vector_store_type}")

        self.client = self.connector_class(ctx)

    def load_document(self, docs):
        """Load documents into the vector database."""
        return self.client.load_document(docs)

    def similar_search(self, doc: str, topk: int):
        """Similarity search in the vector database.

        Args:
            - doc: query text
            - topk: number of docs to return
        """
        return self.client.similar_search(doc, topk)

    def similar_search_with_scores(self, doc: str, topk: int, score_threshold: float):
        """Similarity search with scores in the vector database.

        Return docs and relevance scores in the range [0, 1].

        Args:
            doc(str): query text.
            topk(int): number of docs to return. Defaults to 4.
            score_threshold(float): optional, a floating point value between 0 and 1
                used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.
        """
        return self.client.similar_search_with_scores(doc, topk, score_threshold)

    def vector_name_exists(self):
        """Whether the vector store name exists."""
        return self.client.vector_name_exists()

    def delete_vector_name(self, vector_name):
        """Delete the named vector store.

        Args:
            - vector_name: vector store name
        """
        return self.client.delete_vector_name(vector_name)

    def delete_by_ids(self, ids):
        """Delete vectors by ids.

        Args:
            - ids: vector ids
        """
        return self.client.delete_by_ids(ids=ids)

    def _match(self, vector_store_type) -> bool:
        return bool(connector.get(vector_store_type))

    def _register(self):
        # Register every store class exported by the vector_store package.
        for cls in vector_store.__all__:
            if issubclass(getattr(vector_store, cls), VectorStoreBase):
                _k, _v = cls, getattr(vector_store, cls)
                connector.update({_k: _v})
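_register walks vector_store.__all__, which triggers the package's lazy imports and maps each exported name to its store class, so the connector accepts exactly the names "Chroma", "Milvus", "Weaviate", and "PGVector". A hedged usage sketch; the embedding model, store name, and document content are illustrative:

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

from dbgpt.storage.vector_store.connector import VectorStoreConnector

ctx = {
    "vector_store_name": "my_knowledge_space",  # hypothetical store name
    "embeddings": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
}
conn = VectorStoreConnector("Chroma", ctx)
ids = conn.load_document(
    [Document(page_content="DB-GPT is a database-centric LLM framework.",
              metadata={"source": "readme"})]
)
hits = conn.similar_search("what is DB-GPT?", 4)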
new file: dbgpt/storage/vector_store/milvus_store.py

from __future__ import annotations

import json
import logging
import os
from typing import Any, Iterable, List, Optional

from dbgpt.storage.vector_store.base import VectorStoreBase

logger = logging.getLogger(__name__)


class MilvusStore(VectorStoreBase):
    """Milvus vector store."""

    def __init__(self, ctx: dict) -> None:
        """Init a Milvus storage connection.

        Args:
            ctx (dict): MilvusStore global config.
        """
        from pymilvus import connections

        connect_kwargs = {}
        self.uri = ctx.get("MILVUS_URL", os.getenv("MILVUS_URL"))
        self.port = ctx.get("MILVUS_PORT", os.getenv("MILVUS_PORT"))
        self.username = ctx.get("MILVUS_USERNAME", os.getenv("MILVUS_USERNAME"))
        self.password = ctx.get("MILVUS_PASSWORD", os.getenv("MILVUS_PASSWORD"))
        self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE"))
        self.collection_name = ctx.get("vector_store_name", None)
        self.embedding = ctx.get("embeddings", None)
        self.fields = []
        self.alias = "default"

        # Use HNSW by default.
        self.index_params = {
            "metric_type": "L2",
            "index_type": "HNSW",
            "params": {"M": 8, "efConstruction": 64},
        }
        # Default search params per index type.
        self.index_params_map = {
            "IVF_FLAT": {"params": {"nprobe": 10}},
            "IVF_SQ8": {"params": {"nprobe": 10}},
            "IVF_PQ": {"params": {"nprobe": 10}},
            "HNSW": {"params": {"ef": 10}},
            "RHNSW_FLAT": {"params": {"ef": 10}},
            "RHNSW_SQ": {"params": {"ef": 10}},
            "RHNSW_PQ": {"params": {"ef": 10}},
            "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"params": {"search_k": 10}},
        }
        # Default collection schema.
        self.primary_field = "pk_id"
        self.vector_field = "vector"
        self.text_field = "content"
        self.metadata_field = "metadata"

        if (self.username is None) != (self.password is None):
            raise ValueError(
                "Both username and password must be set to use authentication for Milvus"
            )
        if self.username:
            connect_kwargs["user"] = self.username
            connect_kwargs["password"] = self.password

        connections.connect(
            host=self.uri or "127.0.0.1",
            port=self.port or "19530",
            alias="default",
            # secure=self.secure,
            **connect_kwargs,
        )

    def init_schema_and_load(self, vector_name, documents):
        """Create a Milvus collection, index it with HNSW, and load documents.

        Args:
            vector_name (str): the collection name.
            documents (List[Document]): documents to insert.
        Returns:
            List[str]: primary keys of the inserted documents.
        """
        try:
            from pymilvus import (
                Collection,
                CollectionSchema,
                DataType,
                FieldSchema,
                connections,
                utility,
            )
        except ImportError:
            raise ValueError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            )
        if not connections.has_connection("default"):
            connections.connect(
                host=self.uri or "127.0.0.1",
                port=self.port or "19530",
                alias="default",
                # secure=self.secure,
            )
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        # Embed one text to determine the vector dimension.
        embeddings = self.embedding.embed_query(texts[0])

        if utility.has_collection(self.collection_name):
            # Reuse the existing collection and read its schema.
            self.col = Collection(self.collection_name, using=self.alias)
            self.fields = []
            for x in self.col.schema.fields:
                self.fields.append(x.name)
                if x.auto_id:
                    self.fields.remove(x.name)
                if x.is_primary:
                    self.primary_field = x.name
                if (
                    x.dtype == DataType.FLOAT_VECTOR
                    or x.dtype == DataType.BINARY_VECTOR
                ):
                    self.vector_field = x.name
            return self._add_documents(texts, metadatas)

        dim = len(embeddings)
        primary_field = self.primary_field
        vector_field = self.vector_field
        text_field = self.text_field
        metadata_field = self.metadata_field
        collection_name = vector_name
        fields = []
        # Text field.
        fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
        # Primary key field.
        fields.append(
            FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
        )
        # Vector field.
        fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
        # Metadata field.
        fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
        schema = CollectionSchema(fields)
        # Create the collection and index it.
        collection = Collection(collection_name, schema)
        self.col = collection
        collection.create_index(vector_field, self.index_params)
        collection.load()
        schema = collection.schema
        for x in schema.fields:
            self.fields.append(x.name)
            if x.auto_id:
                self.fields.remove(x.name)
            if x.is_primary:
                self.primary_field = x.name
            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                self.vector_field = x.name
        ids = self._add_documents(texts, metadatas)

        return ids

    def _add_documents(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        partition_name: Optional[str] = None,
        timeout: Optional[int] = None,
    ) -> List[str]:
        """Add text data into Milvus."""
        insert_dict: Any = {self.text_field: list(texts)}
        try:
            text_vector = self.embedding.embed_documents(list(texts))
            insert_dict[self.vector_field] = self._normalization_vectors(text_vector)
        except NotImplementedError:
            insert_dict[self.vector_field] = [
                self.embedding.embed_query(x) for x in texts
            ]
        # Collect the metadata into the insert dict as JSON strings.
        if len(self.fields) > 2 and metadatas is not None:
            for d in metadatas:
                insert_dict.setdefault("metadata", []).append(json.dumps(d))
        # Convert the dict to a list of lists for insertion.
        insert_list = [insert_dict[x] for x in self.fields]
        # Insert into the collection.
        res = self.col.insert(
            insert_list, partition_name=partition_name, timeout=timeout
        )
        # Flush to make sure the data is searchable.
        self.col.flush()
        return res.primary_keys

    def load_document(self, documents):
        """Load documents into the vector database in batches."""
        batch_size = 500
        batched_list = [
            documents[i : i + batch_size] for i in range(0, len(documents), batch_size)
        ]
        doc_ids = []
        for doc_batch in batched_list:
            doc_ids.extend(self.init_schema_and_load(self.collection_name, doc_batch))
        doc_ids = [str(doc_id) for doc_id in doc_ids]
        return doc_ids

    def similar_search(self, text, topk):
        """Similarity search in the vector database."""
        from pymilvus import Collection, DataType

        self.col = Collection(self.collection_name)
        schema = self.col.schema
        for x in schema.fields:
            self.fields.append(x.name)
            if x.auto_id:
                self.fields.remove(x.name)
            if x.is_primary:
                self.primary_field = x.name
            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                self.vector_field = x.name
        _, docs_and_scores = self._search(text, topk)

        from langchain.schema import Document

        return [
            Document(
                metadata=json.loads(doc.metadata.get("metadata", "{}")),
                page_content=doc.page_content,
            )
            for doc, _, _ in docs_and_scores
        ]

    def similar_search_with_scores(self, text, topk, score_threshold):
        """Perform a search on a query string and return results with scores.

        For more information about the search parameters, take a look at the
        pymilvus documentation:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            text (str): the query text.
            topk (int): the number of results to return. Defaults to 4.
            score_threshold (float): optional, a floating point value between 0
                and 1 used to filter the resulting set of retrieved docs.

        Returns:
            List[Tuple[Document, float]]: result docs and scores.
        """
        import warnings

        from pymilvus import Collection, DataType

        self.col = Collection(self.collection_name)
        schema = self.col.schema
        for x in schema.fields:
            self.fields.append(x.name)
            if x.auto_id:
                self.fields.remove(x.name)
            if x.is_primary:
                self.primary_field = x.name
            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                self.vector_field = x.name
        _, docs_and_scores = self._search(text, topk)
        if any(score < 0.0 or score > 1.0 for _, score, _id in docs_and_scores):
            warnings.warn(
                f"similarity scores need to be between 0 and 1, got {docs_and_scores}"
            )

        if score_threshold is not None:
            docs_and_scores = [
                (doc, score)
                for doc, score, _id in docs_and_scores
                if score >= score_threshold
            ]
            if len(docs_and_scores) == 0:
                warnings.warn(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )
        return docs_and_scores

    def _search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        round_decimal: int = -1,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ):
        from langchain.docstore.document import Document

        self.col.load()
        # Use the default search params for the collection's index type.
        if param is None:
            index_type = self.col.indexes[0].params["index_type"]
            param = self.index_params_map[index_type]
        # Embed the query text.
        query_vector = self.embedding.embed_query(query)
        data = [self._normalization_vectors(query_vector)]
        # Determine result metadata fields.
        output_fields = self.fields[:]
        output_fields.remove(self.vector_field)
        # Milvus search.
        res = self.col.search(
            data,
            self.vector_field,
            param,
            k,
            expr=expr,
            output_fields=output_fields,
            partition_names=partition_names,
            round_decimal=round_decimal,
            timeout=60,
            **kwargs,
        )
        ret = []
        for result in res[0]:
            meta = {x: result.entity.get(x) for x in output_fields}
            ret.append(
                (
                    Document(page_content=meta.pop(self.text_field), metadata=meta),
                    self._default_relevance_score_fn(result.distance),
                    result.id,
                )
            )

        return data[0], ret

    def vector_name_exists(self):
        """Whether the vector store name (collection) exists."""
        from pymilvus import utility

        return utility.has_collection(self.collection_name)

    def delete_vector_name(self, vector_name):
        """Drop a Milvus collection by name."""
        from pymilvus import utility

        logger.info(f"milvus vector_name:{vector_name} begin delete...")
        utility.drop_collection(vector_name)
        return True

    def delete_by_ids(self, ids):
        """Delete Milvus vectors by a comma-separated string of ids."""
        from pymilvus import Collection

        logger.info("begin delete milvus ids...")
        self.col = Collection(self.collection_name)
        delete_ids = ids.split(",")
        doc_ids = [int(doc_id) for doc_id in delete_ids]
        delete_expr = f"{self.primary_field} in {doc_ids}"
        self.col.delete(delete_expr)
        return True
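A hedged usage sketch for MilvusStore; the host, port, collection name, and embedding model are illustrative and assume a Milvus server reachable at the given address:

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

from dbgpt.storage.vector_store.milvus_store import MilvusStore

ctx = {
    "MILVUS_URL": "127.0.0.1",              # assumed local Milvus instance
    "MILVUS_PORT": "19530",
    "vector_store_name": "doc_collection",  # hypothetical collection name
    "embeddings": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
}
store = MilvusStore(ctx)
doc_ids = store.load_document(
    [Document(page_content="DB-GPT supports Milvus.", metadata={"source": "docs"})]
)
docs = store.similar_search("which stores are supported?", 4)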
new file: dbgpt/storage/vector_store/pgvector_store.py

from typing import Any
import logging
from dbgpt.storage.vector_store.base import VectorStoreBase
from dbgpt._private.config import Config

logger = logging.getLogger(__name__)

CFG = Config()


class PGVectorStore(VectorStoreBase):
    """`Postgres.PGVector` vector store.

    To use this, you should have the ``pgvector`` python package installed.
    """

    def __init__(self, ctx: dict) -> None:
        """Init pgvector storage."""
        from langchain.vectorstores import PGVector

        self.ctx = ctx
        self.connection_string = ctx.get("connection_string", None)
        self.embeddings = ctx.get("embeddings", None)
        self.collection_name = ctx.get("vector_store_name", None)

        self.vector_store_client = PGVector(
            embedding_function=self.embeddings,
            collection_name=self.collection_name,
            connection_string=self.connection_string,
        )

    def similar_search(self, text, topk, **kwargs: Any):
        return self.vector_store_client.similarity_search(text, topk)

    def vector_name_exists(self):
        try:
            self.vector_store_client.create_collection()
            return True
        except Exception as e:
            logger.error("vector_name_exists error: %s", e)
            return False

    def load_document(self, documents):
        # add_documents is LangChain's instance-level API for inserting
        # pre-built Documents into an existing PGVector store.
        return self.vector_store_client.add_documents(documents)

    def delete_vector_name(self, vector_name):
        return self.vector_store_client.delete_collection()

    def delete_by_ids(self, ids):
        return self.vector_store_client.delete(ids)
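A hedged usage sketch for PGVectorStore; the connection string follows LangChain PGVector's SQLAlchemy format, and the credentials, database, and collection name are illustrative. It assumes a Postgres instance with the pgvector extension installed:

from langchain.embeddings import HuggingFaceEmbeddings

from dbgpt.storage.vector_store.pgvector_store import PGVectorStore

ctx = {
    "vector_store_name": "doc_collection",  # hypothetical collection name
    "embeddings": HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2"),
    "connection_string": "postgresql+psycopg2://user:password@localhost:5432/dbgpt",
}
store = PGVectorStore(ctx)
docs = store.similar_search("what is DB-GPT?", 4)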
new file: dbgpt/storage/vector_store/weaviate_store.py

import os
import logging
from langchain.schema import Document

from dbgpt._private.config import Config
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.storage.vector_store.base import VectorStoreBase

logger = logging.getLogger(__name__)
CFG = Config()


class WeaviateStore(VectorStoreBase):
    """Weaviate vector store."""

    def __init__(self, ctx: dict) -> None:
        """Initialize with the Weaviate client."""
        try:
            import weaviate
        except ImportError:
            raise ValueError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )

        self.ctx = ctx
        self.weaviate_url = ctx.get("WEAVIATE_URL", os.getenv("WEAVIATE_URL"))
        self.embedding = ctx.get("embeddings", None)
        self.vector_name = ctx["vector_store_name"]
        self.persist_dir = os.path.join(
            KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
        )

        self.vector_store_client = weaviate.Client(self.weaviate_url)

    def similar_search(self, text: str, topk: int):
        """Perform a similarity search in Weaviate."""
        logger.info("Weaviate similar search")
        # nearText = {
        #     "concepts": [text],
        #     "distance": 0.75,  # prior to v1.14 use "certainty" instead of "distance"
        # }
        # vector = self.embedding.embed_query(text)
        response = (
            self.vector_store_client.query.get(
                self.vector_name, ["metadata", "page_content"]
            )
            # .with_near_vector({"vector": vector})
            .with_limit(topk)
            .do()
        )
        res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]]
        docs = []
        for r in res:
            docs.append(
                Document(
                    page_content=r["page_content"],
                    metadata={"metadata": r["metadata"]},
                )
            )
        return docs

    def vector_name_exists(self) -> bool:
        """Check whether the vector name exists as a class in Weaviate.

        Returns:
            bool: True if the vector name exists, False otherwise.
        """
        try:
            if self.vector_store_client.schema.get(self.vector_name):
                return True
            return False
        except Exception as e:
            logger.error("vector_name_exists error: %s", e)
            return False

    def _default_schema(self) -> None:
        """Create the Weaviate schema: a class with metadata and page_content
        text properties.
        """
        schema = {
            "classes": [
                {
                    "class": self.vector_name,
                    "description": "A document with metadata and text",
                    # A vectorizer module such as "text2vec-transformers" can be
                    # configured here via "moduleConfig" and "vectorizer".
                    "properties": [
                        {
                            "dataType": ["text"],
                            "description": "Metadata of the document",
                            "name": "metadata",
                        },
                        {
                            "dataType": ["text"],
                            "description": "Text content of the document",
                            "name": "page_content",
                        },
                    ],
                }
            ]
        }

        # Create the schema in Weaviate.
        self.vector_store_client.schema.create(schema)

    def load_document(self, documents: list) -> None:
        """Load documents into Weaviate."""
        logger.info("Weaviate load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]

        # Import the data in batches of 100.
        with self.vector_store_client.batch as batch:
            batch.batch_size = 100

            for i in range(len(texts)):
                properties = {
                    "metadata": metadatas[i]["source"],
                    "page_content": texts[i],
                }

                self.vector_store_client.batch.add_data_object(
                    data_object=properties, class_name=self.vector_name
                )
            self.vector_store_client.batch.flush()
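A hedged usage sketch for WeaviateStore; the URL and class name are illustrative, and the class is created up front via _default_schema() on the assumption that loading expects it to exist:

from langchain.schema import Document

from dbgpt.storage.vector_store.weaviate_store import WeaviateStore

ctx = {
    "WEAVIATE_URL": "http://localhost:8080",  # assumed local Weaviate instance
    "vector_store_name": "DocumentChunk",     # hypothetical Weaviate class name
}
store = WeaviateStore(ctx)
if not store.vector_name_exists():
    store._default_schema()  # create the class before the first load
store.load_document(
    [Document(page_content="DB-GPT supports Weaviate.", metadata={"source": "docs"})]
)
docs = store.similar_search("which stores are supported?", 4)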