fix: milvus metadata bug

Close #800
This commit is contained in:
aries_ckt 2023-11-16 23:22:14 +08:00
parent 3e560099b4
commit 16a995174b

View File

@ -1,4 +1,6 @@
from __future__ import annotations
import json
import logging
import os
from typing import Any, Iterable, List, Optional, Tuple
@ -30,7 +32,7 @@ class MilvusStore(VectorStoreBase):
self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE"))
self.collection_name = ctx.get("vector_store_name", None)
self.embedding = ctx.get("embeddings", None)
self.fields = []
self.fields = ["metadata"]
self.alias = "default"
# use HNSW by default.
@ -55,6 +57,7 @@ class MilvusStore(VectorStoreBase):
self.primary_field = "pk_id"
self.vector_field = "vector"
self.text_field = "content"
self.metadata_field = "metadata"
if (self.username is None) != (self.password is None):
raise ValueError(
@ -127,6 +130,7 @@ class MilvusStore(VectorStoreBase):
primary_field = self.primary_field
vector_field = self.vector_field
text_field = self.text_field
metadata_field = self.metadata_field
# self.text_field = text_field
collection_name = vector_name
fields = []
@ -141,6 +145,8 @@ class MilvusStore(VectorStoreBase):
)
# vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535))
schema = CollectionSchema(fields)
# Create the collection
collection = Collection(collection_name, schema)
@ -233,11 +239,11 @@ class MilvusStore(VectorStoreBase):
self.embedding.embed_query(x) for x in texts
]
# Collect the metadata into the insert dict.
# self.fields.extend(metadatas[0].keys())
if len(self.fields) > 2 and metadatas is not None:
for d in metadatas:
for key, value in d.items():
if key in self.fields:
insert_dict.setdefault(key, []).append(value)
# for key, value in d.items():
insert_dict.setdefault("metadata", []).append(json.dumps(d))
# Convert dict to list of lists for insertion
insert_list = [insert_dict[x] for x in self.fields]
# Insert into the collection.
@ -261,7 +267,7 @@ class MilvusStore(VectorStoreBase):
doc_ids = [str(doc_id) for doc_id in doc_ids]
return doc_ids
def similar_search(self, text, topk) -> None:
def similar_search(self, text, topk):
from pymilvus import Collection, DataType
"""similar_search in vector database."""
@ -276,7 +282,16 @@ class MilvusStore(VectorStoreBase):
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
_, docs_and_scores = self._search(text, topk)
return [doc for doc, _, _ in docs_and_scores]
from langchain.schema import Document
return [
Document(
metadata=json.loads(doc.metadata.get("metadata", "")),
page_content=doc.page_content,
)
for doc, _, _ in docs_and_scores
]
# return [doc for doc, _, _ in docs_and_scores]
def _search(
self,