fix:milvus metadata bug

Close #800
This commit is contained in:
aries_ckt 2023-11-16 23:22:14 +08:00
parent 3e560099b4
commit 16a995174b

View File

@ -1,4 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import os import os
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
@ -30,7 +32,7 @@ class MilvusStore(VectorStoreBase):
self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE")) self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE"))
self.collection_name = ctx.get("vector_store_name", None) self.collection_name = ctx.get("vector_store_name", None)
self.embedding = ctx.get("embeddings", None) self.embedding = ctx.get("embeddings", None)
self.fields = [] self.fields = ["metadata"]
self.alias = "default" self.alias = "default"
# use HNSW by default. # use HNSW by default.
@ -55,6 +57,7 @@ class MilvusStore(VectorStoreBase):
self.primary_field = "pk_id" self.primary_field = "pk_id"
self.vector_field = "vector" self.vector_field = "vector"
self.text_field = "content" self.text_field = "content"
self.metadata_field = "metadata"
if (self.username is None) != (self.password is None): if (self.username is None) != (self.password is None):
raise ValueError( raise ValueError(
@ -127,6 +130,7 @@ class MilvusStore(VectorStoreBase):
primary_field = self.primary_field primary_field = self.primary_field
vector_field = self.vector_field vector_field = self.vector_field
text_field = self.text_field text_field = self.text_field
metadata_field = self.metadata_field
# self.text_field = text_field # self.text_field = text_field
collection_name = vector_name collection_name = vector_name
fields = [] fields = []
@ -141,6 +145,8 @@ class MilvusStore(VectorStoreBase):
) )
# vector field # vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
# Create the collection # Create the collection
collection = Collection(collection_name, schema) collection = Collection(collection_name, schema)
@ -233,11 +239,11 @@ class MilvusStore(VectorStoreBase):
self.embedding.embed_query(x) for x in texts self.embedding.embed_query(x) for x in texts
] ]
# Collect the metadata into the insert dict. # Collect the metadata into the insert dict.
# self.fields.extend(metadatas[0].keys())
if len(self.fields) > 2 and metadatas is not None: if len(self.fields) > 2 and metadatas is not None:
for d in metadatas: for d in metadatas:
for key, value in d.items(): # for key, value in d.items():
if key in self.fields: insert_dict.setdefault("metadata", []).append(json.dumps(d))
insert_dict.setdefault(key, []).append(value)
# Convert dict to list of lists for insertion # Convert dict to list of lists for insertion
insert_list = [insert_dict[x] for x in self.fields] insert_list = [insert_dict[x] for x in self.fields]
# Insert into the collection. # Insert into the collection.
@ -261,7 +267,7 @@ class MilvusStore(VectorStoreBase):
doc_ids = [str(doc_id) for doc_id in doc_ids] doc_ids = [str(doc_id) for doc_id in doc_ids]
return doc_ids return doc_ids
def similar_search(self, text, topk) -> None: def similar_search(self, text, topk):
from pymilvus import Collection, DataType from pymilvus import Collection, DataType
"""similar_search in vector database.""" """similar_search in vector database."""
@ -276,7 +282,16 @@ class MilvusStore(VectorStoreBase):
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name self.vector_field = x.name
_, docs_and_scores = self._search(text, topk) _, docs_and_scores = self._search(text, topk)
return [doc for doc, _, _ in docs_and_scores] from langchain.schema import Document
return [
Document(
metadata=json.loads(doc.metadata.get("metadata", "")),
page_content=doc.page_content,
)
for doc, _, _ in docs_and_scores
]
# return [doc for doc, _, _ in docs_and_scores]
def _search( def _search(
self, self,