refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

View File: dbgpt/storage/vector_store/__init__.py

@@ -0,0 +1,41 @@
from typing import Any


def _import_pgvector() -> Any:
    from dbgpt.storage.vector_store.pgvector_store import PGVectorStore

    return PGVectorStore


def _import_milvus() -> Any:
    from dbgpt.storage.vector_store.milvus_store import MilvusStore

    return MilvusStore


def _import_chroma() -> Any:
    from dbgpt.storage.vector_store.chroma_store import ChromaStore

    return ChromaStore


def _import_weaviate() -> Any:
    from dbgpt.storage.vector_store.weaviate_store import WeaviateStore

    return WeaviateStore


def __getattr__(name: str) -> Any:
    if name == "Chroma":
        return _import_chroma()
    elif name == "Milvus":
        return _import_milvus()
    elif name == "Weaviate":
        return _import_weaviate()
    elif name == "PGVector":
        return _import_pgvector()
    else:
        raise AttributeError(f"Could not find: {name}")


__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector"]
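
The package exposes its backends through PEP 562's module-level `__getattr__`, so a backend's dependencies are only imported when the matching attribute is first accessed. A minimal sketch of how that plays out for a caller (assuming the optional backend packages are installed):

import dbgpt.storage.vector_store as vs

ChromaStore = vs.Chroma   # first access runs _import_chroma(), which imports chromadb
MilvusStore = vs.Milvus   # likewise triggers _import_milvus() on demand
print(vs.__all__)         # ['Chroma', 'Milvus', 'Weaviate', 'PGVector']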

View File: dbgpt/storage/vector_store/base.py

@@ -0,0 +1,42 @@
import math
from abc import ABC, abstractmethod


class VectorStoreBase(ABC):
    """Base class for vector store databases."""

    @abstractmethod
    def load_document(self, documents) -> None:
        """Load documents into the vector database."""

    @abstractmethod
    def similar_search(self, text, topk):
        """Run a similarity search in the vector database."""

    @abstractmethod
    def vector_name_exists(self) -> bool:
        """Whether the vector store name exists."""
        return False

    @abstractmethod
    def delete_by_ids(self, ids):
        """Delete vectors by ids."""

    @abstractmethod
    def delete_vector_name(self, vector_name):
        """Delete the named vector store."""

    def _normalization_vectors(self, vectors):
        """Normalize vectors by their L2 norm."""
        import numpy as np

        norm = np.linalg.norm(vectors)
        return vectors / norm

    def _default_relevance_score_fn(self, distance: float) -> float:
        """Return a similarity score on a scale of [0, 1]."""
        return 1.0 - distance / math.sqrt(2)
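
`_default_relevance_score_fn` assumes normalized embeddings compared with Euclidean (L2) distance: orthogonal unit vectors sit at distance √2, which the division maps to a score of 0, while identical vectors map to 1. A quick standalone check of that mapping (values illustrative):

import math

def default_relevance_score(distance: float) -> float:
    # Same formula as _default_relevance_score_fn above.
    return 1.0 - distance / math.sqrt(2)

print(default_relevance_score(0.0))            # identical vectors -> 1.0
print(default_relevance_score(math.sqrt(2)))   # orthogonal unit vectors -> 0.0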

View File: dbgpt/storage/vector_store/chroma_store.py

@@ -0,0 +1,99 @@
import logging
import os
from typing import Any

from chromadb import PersistentClient
from chromadb.config import Settings

from dbgpt.configs.model_config import PILOT_PATH
from dbgpt.storage.vector_store.base import VectorStoreBase

logger = logging.getLogger(__name__)


class ChromaStore(VectorStoreBase):
    """Chroma vector store."""

    def __init__(self, ctx: dict) -> None:
        from langchain.vectorstores import Chroma

        self.ctx = ctx
        chroma_path = ctx.get(
            "CHROMA_PERSIST_PATH",
            os.path.join(PILOT_PATH, "data"),
        )
        self.persist_dir = os.path.join(
            chroma_path, ctx["vector_store_name"] + ".vectordb"
        )
        self.embeddings = ctx.get("embeddings", None)
        chroma_settings = Settings(
            # chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
            persist_directory=self.persist_dir,
            anonymized_telemetry=False,
        )
        client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
        collection_metadata = {"hnsw:space": "cosine"}
        self.vector_store_client = Chroma(
            persist_directory=self.persist_dir,
            embedding_function=self.embeddings,
            # client_settings=chroma_settings,
            client=client,
            collection_metadata=collection_metadata,
        )

    def similar_search(self, text, topk, **kwargs: Any):
        logger.info("ChromaStore similar search")
        return self.vector_store_client.similarity_search(text, topk, **kwargs)

    def similar_search_with_scores(self, text, topk, score_threshold):
        """Chroma similar_search_with_score.

        Return docs and relevance scores in the range [0, 1].

        Args:
            text(str): query text.
            topk(int): the number of docs to return. Defaults to 4.
            score_threshold(float): optional, a floating point value between 0 and 1
                used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.
        """
        logger.info("ChromaStore similar search with scores")
        docs_and_scores = (
            self.vector_store_client.similarity_search_with_relevance_scores(
                query=text, k=topk, score_threshold=score_threshold
            )
        )
        return docs_and_scores

    def vector_name_exists(self):
        logger.info(f"Check persist_dir: {self.persist_dir}")
        if not os.path.exists(self.persist_dir):
            return False
        files = os.listdir(self.persist_dir)
        # Skip default file: chroma.sqlite3
        files = list(filter(lambda f: f != "chroma.sqlite3", files))
        return len(files) > 0

    def load_document(self, documents):
        logger.info("ChromaStore load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        ids = self.vector_store_client.add_texts(texts=texts, metadatas=metadatas)
        return ids

    def delete_vector_name(self, vector_name):
        logger.info(f"chroma vector_name:{vector_name} begin delete...")
        self.vector_store_client.delete_collection()
        self._clean_persist_folder()
        return True

    def delete_by_ids(self, ids):
        logger.info("begin delete chroma ids...")
        collection = self.vector_store_client._collection
        collection.delete(ids=ids)

    def _clean_persist_folder(self):
        for root, dirs, files in os.walk(self.persist_dir, topdown=False):
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                os.rmdir(os.path.join(root, name))
        os.rmdir(self.persist_dir)
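
A hedged usage sketch for ChromaStore; the embedding class and store name below are illustrative assumptions, not fixed by this commit:

from langchain.embeddings import HuggingFaceEmbeddings  # assumed installed
from langchain.schema import Document

ctx = {
    "vector_store_name": "my_knowledge",     # hypothetical store name
    "embeddings": HuggingFaceEmbeddings(),   # any LangChain embeddings object
    # "CHROMA_PERSIST_PATH": "/tmp/chroma",  # optional path override
}
store = ChromaStore(ctx)
store.load_document([Document(page_content="hello chroma", metadata={"source": "demo"})])
docs = store.similar_search("hello", topk=4)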

View File: dbgpt/storage/vector_store/connector.py

@@ -0,0 +1,83 @@
from dbgpt.storage import vector_store
from dbgpt.storage.vector_store.base import VectorStoreBase

connector = {}


class VectorStoreConnector:
    """VectorStoreConnector. Connects to different vector databases behind a
    load-document API and a similarity-search API.

    1. load_document: load a knowledge document source into the vector store
       (Chroma, Milvus, Weaviate).
    2. similar_search: similarity search against the vector store.
    """

    def __init__(self, vector_store_type, ctx: dict) -> None:
        """Initialize the vector store connector.

        Args:
            - vector_store_type: vector store type (Milvus, Chroma, Weaviate)
            - ctx: vector store config params.
        """
        self.ctx = ctx
        self._register()
        if self._match(vector_store_type):
            self.connector_class = connector.get(vector_store_type)
        else:
            raise ValueError(f"Vector store type not supported: {vector_store_type}")
        self.client = self.connector_class(ctx)

    def load_document(self, docs):
        """Load documents into the vector database."""
        return self.client.load_document(docs)

    def similar_search(self, doc: str, topk: int):
        """Similarity search in the vector database.

        Args:
            - doc: query text
            - topk: the number of docs to return
        """
        return self.client.similar_search(doc, topk)

    def similar_search_with_scores(self, doc: str, topk: int, score_threshold: float):
        """Similarity search with scores in the vector database.

        Return docs and relevance scores in the range [0, 1].

        Args:
            doc(str): query text.
            topk(int): the number of docs to return. Defaults to 4.
            score_threshold(float): optional, a floating point value between 0 and 1
                used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.
        """
        return self.client.similar_search_with_scores(doc, topk, score_threshold)

    def vector_name_exists(self):
        """Whether the vector store name exists."""
        return self.client.vector_name_exists()

    def delete_vector_name(self, vector_name):
        """Delete the named vector store.

        Args:
            - vector_name: vector store name
        """
        return self.client.delete_vector_name(vector_name)

    def delete_by_ids(self, ids):
        """Delete vectors by ids.

        Args:
            - ids: vector ids
        """
        return self.client.delete_by_ids(ids=ids)

    def _match(self, vector_store_type) -> bool:
        return bool(connector.get(vector_store_type))

    def _register(self):
        for cls in vector_store.__all__:
            if issubclass(getattr(vector_store, cls), VectorStoreBase):
                _k, _v = cls, getattr(vector_store, cls)
                connector.update({_k: _v})
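
`_register` resolves every name in `vector_store.__all__` through `getattr`, triggering the package's lazy imports, and maps each type string to its implementation class so the connector can dispatch by name. A hedged sketch of that dispatch (store name and embeddings are illustrative):

from langchain.embeddings import HuggingFaceEmbeddings  # assumed installed

ctx = {
    "vector_store_name": "my_knowledge",    # hypothetical store name
    "embeddings": HuggingFaceEmbeddings(),
}
conn = VectorStoreConnector("Chroma", ctx)  # "Chroma" must appear in __all__
results = conn.similar_search("what is DB-GPT?", topk=4)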

View File: dbgpt/storage/vector_store/milvus_store.py

@@ -0,0 +1,374 @@
from __future__ import annotations

import json
import logging
import os
from typing import Any, Iterable, List, Optional

from dbgpt.storage.vector_store.base import VectorStoreBase

logger = logging.getLogger(__name__)


class MilvusStore(VectorStoreBase):
    """Milvus vector store."""

    def __init__(self, ctx: dict) -> None:
        """Init a Milvus storage connection.

        Args:
            ctx (dict): MilvusStore global config.
        """
        from pymilvus import connections

        connect_kwargs = {}
        self.uri = ctx.get("MILVUS_URL", os.getenv("MILVUS_URL"))
        self.port = ctx.get("MILVUS_PORT", os.getenv("MILVUS_PORT"))
        self.username = ctx.get("MILVUS_USERNAME", os.getenv("MILVUS_USERNAME"))
        self.password = ctx.get("MILVUS_PASSWORD", os.getenv("MILVUS_PASSWORD"))
        self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE"))
        self.collection_name = ctx.get("vector_store_name", None)
        self.embedding = ctx.get("embeddings", None)
        self.fields = []
        self.alias = "default"

        # use HNSW by default.
        self.index_params = {
            "metric_type": "L2",
            "index_type": "HNSW",
            "params": {"M": 8, "efConstruction": 64},
        }
        # search params for each supported index type.
        self.index_params_map = {
            "IVF_FLAT": {"params": {"nprobe": 10}},
            "IVF_SQ8": {"params": {"nprobe": 10}},
            "IVF_PQ": {"params": {"nprobe": 10}},
            "HNSW": {"params": {"ef": 10}},
            "RHNSW_FLAT": {"params": {"ef": 10}},
            "RHNSW_SQ": {"params": {"ef": 10}},
            "RHNSW_PQ": {"params": {"ef": 10}},
            "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
            "ANNOY": {"params": {"search_k": 10}},
        }

        # default collection schema
        self.primary_field = "pk_id"
        self.vector_field = "vector"
        self.text_field = "content"
        self.metadata_field = "metadata"

        if (self.username is None) != (self.password is None):
            raise ValueError(
                "Both username and password must be set to use authentication for Milvus"
            )
        if self.username:
            connect_kwargs["user"] = self.username
            connect_kwargs["password"] = self.password
        # keep credentials around so reconnects can reuse them.
        self.connect_kwargs = connect_kwargs

        connections.connect(
            host=self.uri or "127.0.0.1",
            port=self.port or "19530",
            alias="default",
            # secure=self.secure,
            **self.connect_kwargs,
        )

    def init_schema_and_load(self, vector_name, documents):
        """Create a Milvus collection, index it with HNSW, and load documents.

        Args:
            vector_name (str): your collection name.
            documents (List[str]): texts to insert.

        Returns:
            List[str]: primary keys of the inserted documents.
        """
        try:
            from pymilvus import (
                Collection,
                CollectionSchema,
                DataType,
                FieldSchema,
                connections,
                utility,
            )
        except ImportError:
            raise ValueError(
                "Could not import pymilvus python package. "
                "Please install it with `pip install pymilvus`."
            )
        if not connections.has_connection("default"):
            connections.connect(
                host=self.uri or "127.0.0.1",
                port=self.port or "19530",
                alias="default",
                # secure=self.secure,
                **self.connect_kwargs,
            )
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        embeddings = self.embedding.embed_query(texts[0])

        if utility.has_collection(self.collection_name):
            self.col = Collection(self.collection_name, using=self.alias)
            self.fields = []
            for x in self.col.schema.fields:
                self.fields.append(x.name)
                if x.auto_id:
                    self.fields.remove(x.name)
                if x.is_primary:
                    self.primary_field = x.name
                if (
                    x.dtype == DataType.FLOAT_VECTOR
                    or x.dtype == DataType.BINARY_VECTOR
                ):
                    self.vector_field = x.name
            return self._add_documents(texts, metadatas)

        dim = len(embeddings)
        # Generate unique names
        primary_field = self.primary_field
        vector_field = self.vector_field
        text_field = self.text_field
        metadata_field = self.metadata_field
        collection_name = vector_name
        fields = []
        max_length = 0
        for y in texts:
            max_length = max(max_length, len(y))
        # Create the text field
        fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
        # primary key field
        fields.append(
            FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
        )
        # vector field
        fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
        fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
        schema = CollectionSchema(fields)
        # Create the collection
        collection = Collection(collection_name, schema)
        self.col = collection
        # index parameters for the collection
        index = self.index_params
        # milvus index
        collection.create_index(vector_field, index)
        collection.load()
        schema = collection.schema
        for x in schema.fields:
            self.fields.append(x.name)
            if x.auto_id:
                self.fields.remove(x.name)
            if x.is_primary:
                self.primary_field = x.name
            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                self.vector_field = x.name
        ids = self._add_documents(texts, metadatas)
        return ids

    def _add_documents(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        partition_name: Optional[str] = None,
        timeout: Optional[int] = None,
    ) -> List[str]:
        """Add text data into Milvus."""
        insert_dict: Any = {self.text_field: list(texts)}
        try:
            text_vector = self.embedding.embed_documents(list(texts))
            insert_dict[self.vector_field] = self._normalization_vectors(text_vector)
        except NotImplementedError:
            insert_dict[self.vector_field] = [
                self.embedding.embed_query(x) for x in texts
            ]
        # Collect the metadata into the insert dict.
        if len(self.fields) > 2 and metadatas is not None:
            for d in metadatas:
                insert_dict.setdefault("metadata", []).append(json.dumps(d))
        # Convert dict to list of lists for insertion
        insert_list = [insert_dict[x] for x in self.fields]
        # Insert into the collection.
        res = self.col.insert(
            insert_list, partition_name=partition_name, timeout=timeout
        )
        # make sure data is searchable.
        self.col.flush()
        return res.primary_keys

    def load_document(self, documents) -> List[str]:
        """Load documents into the vector database in batches."""
        batch_size = 500
        batched_list = [
            documents[i : i + batch_size] for i in range(0, len(documents), batch_size)
        ]
        doc_ids = []
        for doc_batch in batched_list:
            doc_ids.extend(self.init_schema_and_load(self.collection_name, doc_batch))
        doc_ids = [str(doc_id) for doc_id in doc_ids]
        return doc_ids

    def similar_search(self, text, topk):
        """Similarity search in the vector database."""
        from langchain.schema import Document
        from pymilvus import Collection, DataType

        self.col = Collection(self.collection_name)
        schema = self.col.schema
        for x in schema.fields:
            self.fields.append(x.name)
            if x.auto_id:
                self.fields.remove(x.name)
            if x.is_primary:
                self.primary_field = x.name
            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                self.vector_field = x.name
        _, docs_and_scores = self._search(text, topk)
        return [
            Document(
                metadata=json.loads(doc.metadata.get("metadata", "")),
                page_content=doc.page_content,
            )
            for doc, _, _ in docs_and_scores
        ]

    def similar_search_with_scores(self, text, topk, score_threshold):
        """Perform a search on a query string and return results with scores.

        For more information about the search parameters, take a look at the pymilvus
        documentation found here:
        https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md

        Args:
            text (str): the query text.
            topk (int): the number of results to return. Defaults to 4.
            score_threshold (float): optional, a floating point value between 0 and 1
                used to filter the retrieved docs; 0 is dissimilar, 1 is most similar.

        Returns:
            List[Tuple[Document, float]]: result docs and scores.
        """
        import warnings

        from pymilvus import Collection, DataType

        self.col = Collection(self.collection_name)
        schema = self.col.schema
        for x in schema.fields:
            self.fields.append(x.name)
            if x.auto_id:
                self.fields.remove(x.name)
            if x.is_primary:
                self.primary_field = x.name
            if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                self.vector_field = x.name
        _, docs_and_scores = self._search(text, topk)
        if any(score < 0.0 or score > 1.0 for _, score, _id in docs_and_scores):
            warnings.warn(
                f"similarity score needs to be between 0 and 1, got {docs_and_scores}"
            )
        if score_threshold is not None:
            docs_and_scores = [
                (doc, score)
                for doc, score, _id in docs_and_scores
                if score >= score_threshold
            ]
            if len(docs_and_scores) == 0:
                warnings.warn(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )
        return docs_and_scores

    def _search(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        partition_names: Optional[List[str]] = None,
        round_decimal: int = -1,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ):
        from langchain.docstore.document import Document

        self.col.load()
        # use default index params.
        if param is None:
            index_type = self.col.indexes[0].params["index_type"]
            param = self.index_params_map[index_type]
        # query text embedding.
        query_vector = self.embedding.embed_query(query)
        data = [self._normalization_vectors(query_vector)]
        # Determine result metadata fields.
        output_fields = self.fields[:]
        output_fields.remove(self.vector_field)
        # milvus search.
        res = self.col.search(
            data,
            self.vector_field,
            param,
            k,
            expr=expr,
            output_fields=output_fields,
            partition_names=partition_names,
            round_decimal=round_decimal,
            timeout=timeout or 60,
            **kwargs,
        )
        ret = []
        for result in res[0]:
            meta = {x: result.entity.get(x) for x in output_fields}
            ret.append(
                (
                    Document(page_content=meta.pop(self.text_field), metadata=meta),
                    self._default_relevance_score_fn(result.distance),
                    result.id,
                )
            )
        return data[0], ret

    def vector_name_exists(self):
        """Whether the vector store name exists."""
        from pymilvus import utility

        return utility.has_collection(self.collection_name)

    def delete_vector_name(self, vector_name):
        """Delete a Milvus collection by name."""
        from pymilvus import utility

        logger.info(f"milvus vector_name:{vector_name} begin delete...")
        utility.drop_collection(vector_name)
        return True

    def delete_by_ids(self, ids):
        """Delete Milvus vectors by ids."""
        from pymilvus import Collection

        logger.info("begin delete milvus ids...")
        self.col = Collection(self.collection_name)
        delete_ids = ids.split(",")
        doc_ids = [int(doc_id) for doc_id in delete_ids]
        delete_expr = f"{self.primary_field} in {doc_ids}"
        self.col.delete(delete_expr)
        return True
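
A hedged MilvusStore sketch, assuming a Milvus instance reachable on the default host and port; the names and values below are illustrative:

from langchain.embeddings import HuggingFaceEmbeddings  # assumed installed
from langchain.schema import Document

ctx = {
    "MILVUS_URL": "127.0.0.1",
    "MILVUS_PORT": "19530",
    "vector_store_name": "doc_collection",   # hypothetical collection name
    "embeddings": HuggingFaceEmbeddings(),
}
store = MilvusStore(ctx)
ids = store.load_document(
    [Document(page_content="hello milvus", metadata={"source": "demo"})]
)                                            # inserted in batches of 500
docs = store.similar_search("hello", topk=4)
store.delete_by_ids(",".join(ids))           # ids passed as a comma-separated string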

View File: dbgpt/storage/vector_store/pgvector_store.py

@@ -0,0 +1,51 @@
import logging
from typing import Any

from dbgpt._private.config import Config
from dbgpt.storage.vector_store.base import VectorStoreBase

logger = logging.getLogger(__name__)

CFG = Config()


class PGVectorStore(VectorStoreBase):
    """`Postgres.PGVector` vector store.

    To use this, you should have the ``pgvector`` python package installed.
    """

    def __init__(self, ctx: dict) -> None:
        """Init pgvector storage."""
        from langchain.vectorstores import PGVector

        self.ctx = ctx
        self.connection_string = ctx.get("connection_string", None)
        self.embeddings = ctx.get("embeddings", None)
        self.collection_name = ctx.get("vector_store_name", None)

        self.vector_store_client = PGVector(
            embedding_function=self.embeddings,
            collection_name=self.collection_name,
            connection_string=self.connection_string,
        )

    def similar_search(self, text, topk, **kwargs: Any):
        return self.vector_store_client.similarity_search(text, topk)

    def vector_name_exists(self):
        try:
            self.vector_store_client.create_collection()
            return True
        except Exception as e:
            logger.error("vector_name_exists error: %s", str(e))
            return False

    def load_document(self, documents):
        # from_documents is a classmethod constructor; add_documents inserts
        # into the existing collection and returns the ids.
        return self.vector_store_client.add_documents(documents)

    def delete_vector_name(self, vector_name):
        return self.vector_store_client.delete_collection()

    def delete_by_ids(self, ids):
        return self.vector_store_client.delete(ids)
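
A hedged PGVectorStore sketch; the connection string and collection name are placeholders:

from langchain.embeddings import HuggingFaceEmbeddings  # assumed installed

ctx = {
    "connection_string": "postgresql+psycopg2://user:pass@localhost:5432/db",
    "vector_store_name": "my_collection",    # hypothetical collection name
    "embeddings": HuggingFaceEmbeddings(),
}
store = PGVectorStore(ctx)
docs = store.similar_search("query text", topk=4)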

View File: dbgpt/storage/vector_store/weaviate_store.py

@@ -0,0 +1,143 @@
import logging
import os

from langchain.schema import Document

from dbgpt._private.config import Config
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.storage.vector_store.base import VectorStoreBase

logger = logging.getLogger(__name__)

CFG = Config()


class WeaviateStore(VectorStoreBase):
    """Weaviate vector store."""

    def __init__(self, ctx: dict) -> None:
        """Initialize with the Weaviate client."""
        try:
            import weaviate
        except ImportError:
            raise ValueError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )
        self.ctx = ctx
        self.weaviate_url = ctx.get("WEAVIATE_URL", os.getenv("WEAVIATE_URL"))
        self.embedding = ctx.get("embeddings", None)
        self.vector_name = ctx["vector_store_name"]
        self.persist_dir = os.path.join(
            KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
        )
        self.vector_store_client = weaviate.Client(self.weaviate_url)

    def similar_search(self, text: str, topk: int) -> list:
        """Perform a similarity search in Weaviate."""
        logger.info("Weaviate similar search")
        # nearText = {
        #     "concepts": [text],
        #     "distance": 0.75,  # prior to v1.14 use "certainty" instead of "distance"
        # }
        # vector = self.embedding.embed_query(text)
        response = (
            self.vector_store_client.query.get(
                self.vector_name, ["metadata", "page_content"]
            )
            # .with_near_vector({"vector": vector})
            .with_limit(topk)
            .do()
        )
        res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]]
        docs = []
        for r in res:
            docs.append(
                Document(
                    page_content=r["page_content"],
                    metadata={"metadata": r["metadata"]},
                )
            )
        return docs

    def vector_name_exists(self) -> bool:
        """Check whether the vector name exists as a class in Weaviate.

        Returns:
            bool: True if the vector name exists, False otherwise.
        """
        try:
            if self.vector_store_client.schema.get(self.vector_name):
                return True
            return False
        except Exception as e:
            logger.error("vector_name_exists error: %s", str(e))
            return False

    def _default_schema(self) -> None:
        """Create the Weaviate schema: a document class with metadata and text
        properties.
        """
        schema = {
            "classes": [
                {
                    "class": self.vector_name,
                    "description": "A document with metadata and text",
                    # "moduleConfig": {
                    #     "text2vec-transformers": {
                    #         "poolingStrategy": "masked_mean",
                    #         "vectorizeClassName": False,
                    #     }
                    # },
                    "properties": [
                        {
                            "dataType": ["text"],
                            # "moduleConfig": {
                            #     "text2vec-transformers": {
                            #         "skip": False,
                            #         "vectorizePropertyName": False,
                            #     }
                            # },
                            "description": "Metadata of the document",
                            "name": "metadata",
                        },
                        {
                            "dataType": ["text"],
                            # "moduleConfig": {
                            #     "text2vec-transformers": {
                            #         "skip": False,
                            #         "vectorizePropertyName": False,
                            #     }
                            # },
                            "description": "Text content of the document",
                            "name": "page_content",
                        },
                    ],
                    # "vectorizer": "text2vec-transformers",
                }
            ]
        }
        # Create the schema in Weaviate
        self.vector_store_client.schema.create(schema)

    def load_document(self, documents: list) -> None:
        """Load documents into Weaviate."""
        logger.info("Weaviate load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        # Import data in batches
        with self.vector_store_client.batch as batch:
            batch.batch_size = 100
            # Batch import all documents
            for i in range(len(texts)):
                properties = {
                    "metadata": metadatas[i]["source"],
                    "page_content": texts[i],
                }
                self.vector_store_client.batch.add_data_object(
                    data_object=properties, class_name=self.vector_name
                )
            self.vector_store_client.batch.flush()
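
A hedged WeaviateStore sketch, assuming a reachable Weaviate instance; the URL and class name are placeholders. Note that load_document expects a `source` key in each document's metadata:

from langchain.schema import Document

ctx = {
    "WEAVIATE_URL": "http://localhost:8080",  # placeholder endpoint
    "vector_store_name": "KnowledgeDemo",     # Weaviate class names are capitalized
}
store = WeaviateStore(ctx)
store._default_schema()  # create the class before the first import
store.load_document([Document(page_content="hello weaviate", metadata={"source": "demo"})])
docs = store.similar_search("hello", topk=4)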