update:milvus append knowledge

This commit is contained in:
aries-ckt
2023-05-25 14:39:14 +08:00
parent d3567fb984
commit 85906a3c45

View File

@@ -1,7 +1,7 @@
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections from pymilvus import Collection, DataType, connections, utility
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase from pilot.vector_store.vector_store_base import VectorStoreBase
@@ -29,6 +29,8 @@ class MilvusStore(VectorStoreBase):
self.secure = ctx.get("secure", None) self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None) self.embedding = ctx.get("embeddings", None)
self.fields = [] self.fields = []
self.alias = "default"
# use HNSW by default. # use HNSW by default.
self.index_params = { self.index_params = {
@@ -48,7 +50,9 @@ class MilvusStore(VectorStoreBase):
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}}, "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}}, "ANNOY": {"params": {"search_k": 10}},
} }
# default collection schema
self.primary_field = "pk_id"
self.vector_field = "vector"
self.text_field = "content" self.text_field = "content"
if (self.username is None) != (self.password is None): if (self.username is None) != (self.password is None):
@@ -98,42 +102,30 @@ class MilvusStore(VectorStoreBase):
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0]) embeddings = self.embedding.embed_query(texts[0])
if utility.has_collection(self.collection_name):
self.col = Collection(
self.collection_name, using=self.alias
)
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
self._add_documents(texts, metadatas)
return self.collection_name
dim = len(embeddings) dim = len(embeddings)
# Generate unique names # Generate unique names
primary_field = "pk_id" primary_field = self.primary_field
vector_field = "vector" vector_field = self.vector_field
text_field = "content" text_field = self.text_field
self.text_field = text_field # self.text_field = text_field
collection_name = vector_name collection_name = vector_name
fields = [] fields = []
# Determine metadata schema
# if metadatas:
# # Check if all metadata keys line up
# key = metadatas[0].keys()
# for x in metadatas:
# if key != x.keys():
# raise ValueError(
# "Mismatched metadata. "
# "Make sure all metadata has the same keys and datatype."
# )
# # Create FieldSchema for each entry in singular metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# if dtype == DataType.UNKNOWN:
# raise ValueError(f"Unrecognized datatype for {key}.")
# elif dtype == DataType.VARCHAR:
# # Find out max length text based metadata
# max_length = 0
# for subvalues in metadatas:
# max_length = max(max_length, len(subvalues[key]))
# fields.append(
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
# )
# else:
# fields.append(FieldSchema(key, dtype))
# Find out max length of texts
max_length = 0 max_length = 0
for y in texts: for y in texts:
max_length = max(max_length, len(y)) max_length = max(max_length, len(y))
@@ -147,7 +139,6 @@ class MilvusStore(VectorStoreBase):
) )
# vector field # vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# milvus the schema for the collection
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
# Create the collection # Create the collection
collection = Collection(collection_name, schema) collection = Collection(collection_name, schema)
@@ -165,7 +156,7 @@ class MilvusStore(VectorStoreBase):
self.primary_field = x.name self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name self.vector_field = x.name
self._add_texts(texts, metadatas) self._add_documents(texts, metadatas)
return self.collection_name return self.collection_name
@@ -224,7 +215,7 @@ class MilvusStore(VectorStoreBase):
# ) # )
# return _text # return _text
def _add_texts( def _add_documents(
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
@@ -320,3 +311,6 @@ class MilvusStore(VectorStoreBase):
) )
return data[0], ret return data[0], ret
def close(self) -> None:
    """Disconnect this store's Milvus connection.

    pymilvus keeps connections in a registry keyed by alias, and
    ``connections.disconnect`` requires that alias as an argument; calling
    it with no arguments raises ``TypeError`` instead of disconnecting.
    The alias used is the one stored on the instance (``self.alias``).
    """
    connections.disconnect(self.alias)