mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 21:12:13 +00:00
324 lines
12 KiB
Python
324 lines
12 KiB
Python
from typing import Any, Iterable, List, Optional, Tuple
|
|
|
|
from langchain.docstore.document import Document
|
|
from pymilvus import Collection, DataType, connections, utility
|
|
|
|
from pilot.configs.config import Config
|
|
from pilot.vector_store.vector_store_base import VectorStoreBase
|
|
|
|
|
|
CFG = Config()
|
|
|
|
|
|
class MilvusStore(VectorStoreBase):
|
|
"""Milvus database"""
|
|
|
|
def __init__(self, ctx: {}) -> None:
|
|
"""init a milvus storage connection.
|
|
|
|
Args:
|
|
ctx ({}): MilvusStore global config.
|
|
"""
|
|
# self.configure(cfg)
|
|
|
|
connect_kwargs = {}
|
|
self.uri = CFG.MILVUS_URL
|
|
self.port = CFG.MILVUS_PORT
|
|
self.username = CFG.MILVUS_USERNAME
|
|
self.password = CFG.MILVUS_PASSWORD
|
|
self.collection_name = ctx.get("vector_store_name", None)
|
|
self.secure = ctx.get("secure", None)
|
|
self.embedding = ctx.get("embeddings", None)
|
|
self.fields = []
|
|
self.alias = "default"
|
|
|
|
# use HNSW by default.
|
|
self.index_params = {
|
|
"metric_type": "L2",
|
|
"index_type": "HNSW",
|
|
"params": {"M": 8, "efConstruction": 64},
|
|
}
|
|
# use HNSW by default.
|
|
self.index_params_map = {
|
|
"IVF_FLAT": {"params": {"nprobe": 10}},
|
|
"IVF_SQ8": {"params": {"nprobe": 10}},
|
|
"IVF_PQ": {"params": {"nprobe": 10}},
|
|
"HNSW": {"params": {"ef": 10}},
|
|
"RHNSW_FLAT": {"params": {"ef": 10}},
|
|
"RHNSW_SQ": {"params": {"ef": 10}},
|
|
"RHNSW_PQ": {"params": {"ef": 10}},
|
|
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
|
"ANNOY": {"params": {"search_k": 10}},
|
|
}
|
|
# default collection schema
|
|
self.primary_field = "pk_id"
|
|
self.vector_field = "vector"
|
|
self.text_field = "content"
|
|
|
|
if (self.username is None) != (self.password is None):
|
|
raise ValueError(
|
|
"Both username and password must be set to use authentication for Milvus"
|
|
)
|
|
if self.username:
|
|
connect_kwargs["user"] = self.username
|
|
connect_kwargs["password"] = self.password
|
|
|
|
connections.connect(
|
|
host=self.uri or "127.0.0.1",
|
|
port=self.port or "19530",
|
|
alias="default"
|
|
# secure=self.secure,
|
|
)
|
|
|
|
def init_schema_and_load(self, vector_name, documents):
|
|
"""Create a Milvus collection, indexes it with HNSW, load document.
|
|
Args:
|
|
vector_name (Embeddings): your collection name.
|
|
documents (List[str]): Text to insert.
|
|
Returns:
|
|
VectorStore: The MilvusStore vector store.
|
|
"""
|
|
try:
|
|
from pymilvus import (
|
|
Collection,
|
|
CollectionSchema,
|
|
DataType,
|
|
FieldSchema,
|
|
connections,
|
|
)
|
|
from pymilvus.orm.types import infer_dtype_bydata
|
|
except ImportError:
|
|
raise ValueError(
|
|
"Could not import pymilvus python package. "
|
|
"Please install it with `pip install pymilvus`."
|
|
)
|
|
if not connections.has_connection("default"):
|
|
connections.connect(
|
|
host=self.uri or "127.0.0.1",
|
|
port=self.port or "19530",
|
|
alias="default"
|
|
# secure=self.secure,
|
|
)
|
|
texts = [d.page_content for d in documents]
|
|
metadatas = [d.metadata for d in documents]
|
|
embeddings = self.embedding.embed_query(texts[0])
|
|
|
|
if utility.has_collection(self.collection_name):
|
|
self.col = Collection(self.collection_name, using=self.alias)
|
|
self.fields = []
|
|
for x in self.col.schema.fields:
|
|
self.fields.append(x.name)
|
|
if x.auto_id:
|
|
self.fields.remove(x.name)
|
|
if x.is_primary:
|
|
self.primary_field = x.name
|
|
if (
|
|
x.dtype == DataType.FLOAT_VECTOR
|
|
or x.dtype == DataType.BINARY_VECTOR
|
|
):
|
|
self.vector_field = x.name
|
|
self._add_documents(texts, metadatas)
|
|
return self.collection_name
|
|
|
|
dim = len(embeddings)
|
|
# Generate unique names
|
|
primary_field = self.primary_field
|
|
vector_field = self.vector_field
|
|
text_field = self.text_field
|
|
# self.text_field = text_field
|
|
collection_name = vector_name
|
|
fields = []
|
|
max_length = 0
|
|
for y in texts:
|
|
max_length = max(max_length, len(y))
|
|
# Create the text field
|
|
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
|
|
# primary key field
|
|
fields.append(
|
|
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
|
|
)
|
|
# vector field
|
|
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
|
schema = CollectionSchema(fields)
|
|
# Create the collection
|
|
collection = Collection(collection_name, schema)
|
|
self.col = collection
|
|
# index parameters for the collection
|
|
index = self.index_params
|
|
# milvus index
|
|
collection.create_index(vector_field, index)
|
|
schema = collection.schema
|
|
for x in schema.fields:
|
|
self.fields.append(x.name)
|
|
if x.auto_id:
|
|
self.fields.remove(x.name)
|
|
if x.is_primary:
|
|
self.primary_field = x.name
|
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
|
self.vector_field = x.name
|
|
self._add_documents(texts, metadatas)
|
|
|
|
return self.collection_name
|
|
|
|
# def init_schema(self) -> None:
|
|
# """Initialize collection in milvus database."""
|
|
# fields = [
|
|
# FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
|
# FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.model_config["dim"]),
|
|
# FieldSchema(name="raw_text", dtype=DataType.VARCHAR, max_length=65535),
|
|
# ]
|
|
#
|
|
# # create collection if not exist and load it.
|
|
# self.schema = CollectionSchema(fields, "db-gpt memory storage")
|
|
# self.collection = Collection(self.collection_name, self.schema)
|
|
# self.index_params_map = {
|
|
# "IVF_FLAT": {"params": {"nprobe": 10}},
|
|
# "IVF_SQ8": {"params": {"nprobe": 10}},
|
|
# "IVF_PQ": {"params": {"nprobe": 10}},
|
|
# "HNSW": {"params": {"ef": 10}},
|
|
# "RHNSW_FLAT": {"params": {"ef": 10}},
|
|
# "RHNSW_SQ": {"params": {"ef": 10}},
|
|
# "RHNSW_PQ": {"params": {"ef": 10}},
|
|
# "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
|
# "ANNOY": {"params": {"search_k": 10}},
|
|
# }
|
|
#
|
|
# self.index_params = {
|
|
# "metric_type": "IP",
|
|
# "index_type": "HNSW",
|
|
# "params": {"M": 8, "efConstruction": 64},
|
|
# }
|
|
# # create index if not exist.
|
|
# if not self.collection.has_index():
|
|
# self.collection.release()
|
|
# self.collection.create_index(
|
|
# "vector",
|
|
# self.index_params,
|
|
# index_name="vector",
|
|
# )
|
|
# info = self.collection.describe()
|
|
# self.collection.load()
|
|
|
|
# def insert(self, text, model_config) -> str:
|
|
# """Add an embedding of data into milvus.
|
|
# Args:
|
|
# text (str): The raw text to construct embedding index.
|
|
# Returns:
|
|
# str: log.
|
|
# """
|
|
# # embedding = get_ada_embedding(data)
|
|
# embeddings = HuggingFaceEmbeddings(model_name=self.model_config["model_name"])
|
|
# result = self.collection.insert([embeddings.embed_documents(text), text])
|
|
# _text = (
|
|
# "Inserting data into memory at primary key: "
|
|
# f"{result.primary_keys[0]}:\n data: {text}"
|
|
# )
|
|
# return _text
|
|
|
|
def _add_documents(
|
|
self,
|
|
texts: Iterable[str],
|
|
metadatas: Optional[List[dict]] = None,
|
|
partition_name: Optional[str] = None,
|
|
timeout: Optional[int] = None,
|
|
) -> List[str]:
|
|
"""add text data into Milvus."""
|
|
insert_dict: Any = {self.text_field: list(texts)}
|
|
try:
|
|
insert_dict[self.vector_field] = self.embedding.embed_documents(list(texts))
|
|
except NotImplementedError:
|
|
insert_dict[self.vector_field] = [
|
|
self.embedding.embed_query(x) for x in texts
|
|
]
|
|
# Collect the metadata into the insert dict.
|
|
if len(self.fields) > 2 and metadatas is not None:
|
|
for d in metadatas:
|
|
for key, value in d.items():
|
|
if key in self.fields:
|
|
insert_dict.setdefault(key, []).append(value)
|
|
# Convert dict to list of lists for insertion
|
|
insert_list = [insert_dict[x] for x in self.fields]
|
|
# Insert into the collection.
|
|
res = self.col.insert(
|
|
insert_list, partition_name=partition_name, timeout=timeout
|
|
)
|
|
# make sure data is searchable.
|
|
self.col.flush()
|
|
return res.primary_keys
|
|
|
|
def load_document(self, documents) -> None:
|
|
"""load document in vector database."""
|
|
# self.init_schema_and_load(self.collection_name, documents)
|
|
batch_size = 500
|
|
batched_list = [
|
|
documents[i : i + batch_size] for i in range(0, len(documents), batch_size)
|
|
]
|
|
# docs = []
|
|
for doc_batch in batched_list:
|
|
self.init_schema_and_load(self.collection_name, doc_batch)
|
|
|
|
def similar_search(self, text, topk) -> None:
|
|
"""similar_search in vector database."""
|
|
self.col = Collection(self.collection_name)
|
|
schema = self.col.schema
|
|
for x in schema.fields:
|
|
self.fields.append(x.name)
|
|
if x.auto_id:
|
|
self.fields.remove(x.name)
|
|
if x.is_primary:
|
|
self.primary_field = x.name
|
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
|
self.vector_field = x.name
|
|
_, docs_and_scores = self._search(text, topk)
|
|
return [doc for doc, _, _ in docs_and_scores]
|
|
|
|
def _search(
|
|
self,
|
|
query: str,
|
|
k: int = 4,
|
|
param: Optional[dict] = None,
|
|
expr: Optional[str] = None,
|
|
partition_names: Optional[List[str]] = None,
|
|
round_decimal: int = -1,
|
|
timeout: Optional[int] = None,
|
|
**kwargs: Any,
|
|
) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]:
|
|
self.col.load()
|
|
# use default index params.
|
|
if param is None:
|
|
index_type = self.col.indexes[0].params["index_type"]
|
|
param = self.index_params_map[index_type]
|
|
# query text embedding.
|
|
data = [self.embedding.embed_query(query)]
|
|
# Determine result metadata fields.
|
|
output_fields = self.fields[:]
|
|
output_fields.remove(self.vector_field)
|
|
# milvus search.
|
|
res = self.col.search(
|
|
data,
|
|
self.vector_field,
|
|
param,
|
|
k,
|
|
expr=expr,
|
|
output_fields=output_fields,
|
|
partition_names=partition_names,
|
|
round_decimal=round_decimal,
|
|
timeout=timeout,
|
|
**kwargs,
|
|
)
|
|
ret = []
|
|
for result in res[0]:
|
|
meta = {x: result.entity.get(x) for x in output_fields}
|
|
ret.append(
|
|
(
|
|
Document(page_content=meta.pop(self.text_field), metadata=meta),
|
|
result.distance,
|
|
result.id,
|
|
)
|
|
)
|
|
|
|
return data[0], ret
|
|
|
|
def close(self):
|
|
connections.disconnect()
|