mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
update:milvus append knowledge
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from typing import Any, Iterable, List, Optional, Tuple
|
from typing import Any, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from pymilvus import Collection, DataType, connections
|
from pymilvus import Collection, DataType, connections, utility
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.vector_store.vector_store_base import VectorStoreBase
|
from pilot.vector_store.vector_store_base import VectorStoreBase
|
||||||
@@ -29,6 +29,8 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.secure = ctx.get("secure", None)
|
self.secure = ctx.get("secure", None)
|
||||||
self.embedding = ctx.get("embeddings", None)
|
self.embedding = ctx.get("embeddings", None)
|
||||||
self.fields = []
|
self.fields = []
|
||||||
|
self.alias = "default"
|
||||||
|
|
||||||
|
|
||||||
# use HNSW by default.
|
# use HNSW by default.
|
||||||
self.index_params = {
|
self.index_params = {
|
||||||
@@ -48,7 +50,9 @@ class MilvusStore(VectorStoreBase):
|
|||||||
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
|
||||||
"ANNOY": {"params": {"search_k": 10}},
|
"ANNOY": {"params": {"search_k": 10}},
|
||||||
}
|
}
|
||||||
|
# default collection schema
|
||||||
|
self.primary_field = "pk_id"
|
||||||
|
self.vector_field = "vector"
|
||||||
self.text_field = "content"
|
self.text_field = "content"
|
||||||
|
|
||||||
if (self.username is None) != (self.password is None):
|
if (self.username is None) != (self.password is None):
|
||||||
@@ -98,42 +102,30 @@ class MilvusStore(VectorStoreBase):
|
|||||||
texts = [d.page_content for d in documents]
|
texts = [d.page_content for d in documents]
|
||||||
metadatas = [d.metadata for d in documents]
|
metadatas = [d.metadata for d in documents]
|
||||||
embeddings = self.embedding.embed_query(texts[0])
|
embeddings = self.embedding.embed_query(texts[0])
|
||||||
|
|
||||||
|
if utility.has_collection(self.collection_name):
|
||||||
|
self.col = Collection(
|
||||||
|
self.collection_name, using=self.alias
|
||||||
|
)
|
||||||
|
for x in self.col.schema.fields:
|
||||||
|
self.fields.append(x.name)
|
||||||
|
if x.auto_id:
|
||||||
|
self.fields.remove(x.name)
|
||||||
|
if x.is_primary:
|
||||||
|
self.primary_field = x.name
|
||||||
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||||
|
self.vector_field = x.name
|
||||||
|
self._add_documents(texts, metadatas)
|
||||||
|
return self.collection_name
|
||||||
|
|
||||||
dim = len(embeddings)
|
dim = len(embeddings)
|
||||||
# Generate unique names
|
# Generate unique names
|
||||||
primary_field = "pk_id"
|
primary_field = self.primary_field
|
||||||
vector_field = "vector"
|
vector_field = self.vector_field
|
||||||
text_field = "content"
|
text_field = self.text_field
|
||||||
self.text_field = text_field
|
# self.text_field = text_field
|
||||||
collection_name = vector_name
|
collection_name = vector_name
|
||||||
fields = []
|
fields = []
|
||||||
# Determine metadata schema
|
|
||||||
# if metadatas:
|
|
||||||
# # Check if all metadata keys line up
|
|
||||||
# key = metadatas[0].keys()
|
|
||||||
# for x in metadatas:
|
|
||||||
# if key != x.keys():
|
|
||||||
# raise ValueError(
|
|
||||||
# "Mismatched metadata. "
|
|
||||||
# "Make sure all metadata has the same keys and datatype."
|
|
||||||
# )
|
|
||||||
# # Create FieldSchema for each entry in singular metadata.
|
|
||||||
# for key, value in metadatas[0].items():
|
|
||||||
# # Infer the corresponding datatype of the metadata
|
|
||||||
# dtype = infer_dtype_bydata(value)
|
|
||||||
# if dtype == DataType.UNKNOWN:
|
|
||||||
# raise ValueError(f"Unrecognized datatype for {key}.")
|
|
||||||
# elif dtype == DataType.VARCHAR:
|
|
||||||
# # Find out max length text based metadata
|
|
||||||
# max_length = 0
|
|
||||||
# for subvalues in metadatas:
|
|
||||||
# max_length = max(max_length, len(subvalues[key]))
|
|
||||||
# fields.append(
|
|
||||||
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# fields.append(FieldSchema(key, dtype))
|
|
||||||
|
|
||||||
# Find out max length of texts
|
|
||||||
max_length = 0
|
max_length = 0
|
||||||
for y in texts:
|
for y in texts:
|
||||||
max_length = max(max_length, len(y))
|
max_length = max(max_length, len(y))
|
||||||
@@ -147,7 +139,6 @@ class MilvusStore(VectorStoreBase):
|
|||||||
)
|
)
|
||||||
# vector field
|
# vector field
|
||||||
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
|
||||||
# milvus the schema for the collection
|
|
||||||
schema = CollectionSchema(fields)
|
schema = CollectionSchema(fields)
|
||||||
# Create the collection
|
# Create the collection
|
||||||
collection = Collection(collection_name, schema)
|
collection = Collection(collection_name, schema)
|
||||||
@@ -165,7 +156,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
self.primary_field = x.name
|
self.primary_field = x.name
|
||||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||||
self.vector_field = x.name
|
self.vector_field = x.name
|
||||||
self._add_texts(texts, metadatas)
|
self._add_documents(texts, metadatas)
|
||||||
|
|
||||||
return self.collection_name
|
return self.collection_name
|
||||||
|
|
||||||
@@ -224,7 +215,7 @@ class MilvusStore(VectorStoreBase):
|
|||||||
# )
|
# )
|
||||||
# return _text
|
# return _text
|
||||||
|
|
||||||
def _add_texts(
|
def _add_documents(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
@@ -320,3 +311,6 @@ class MilvusStore(VectorStoreBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return data[0], ret
|
return data[0], ret
|
||||||
|
|
||||||
|
def close(self):
    """Disconnect the Milvus connection used by this store.

    pymilvus' ``connections.disconnect`` requires the connection alias as
    an argument; calling it with no argument raises ``TypeError``. The
    alias passed here is ``self.alias``, the same alias this class uses
    when it opens its ``Collection`` (``using=self.alias``).
    """
    connections.disconnect(self.alias)
|
Reference in New Issue
Block a user