diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index ee304fe25..582289d80 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import json import logging import os from typing import Any, Iterable, List, Optional, Tuple @@ -30,7 +32,7 @@ class MilvusStore(VectorStoreBase): self.secure = ctx.get("MILVUS_SECURE", os.getenv("MILVUS_SECURE")) self.collection_name = ctx.get("vector_store_name", None) self.embedding = ctx.get("embeddings", None) - self.fields = [] + self.fields = ["metadata"] self.alias = "default" # use HNSW by default. @@ -55,6 +57,7 @@ class MilvusStore(VectorStoreBase): self.primary_field = "pk_id" self.vector_field = "vector" self.text_field = "content" + self.metadata_field = "metadata" if (self.username is None) != (self.password is None): raise ValueError( @@ -127,6 +130,7 @@ class MilvusStore(VectorStoreBase): primary_field = self.primary_field vector_field = self.vector_field text_field = self.text_field + metadata_field = self.metadata_field # self.text_field = text_field collection_name = vector_name fields = [] @@ -141,6 +145,8 @@ class MilvusStore(VectorStoreBase): ) # vector field fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) + + fields.append(FieldSchema(metadata_field, DataType.VARCHAR, max_length=65535)) schema = CollectionSchema(fields) # Create the collection collection = Collection(collection_name, schema) @@ -233,11 +239,11 @@ class MilvusStore(VectorStoreBase): self.embedding.embed_query(x) for x in texts ] # Collect the metadata into the insert dict. + # self.fields.extend(metadatas[0].keys()) if len(self.fields) > 2 and metadatas is not None: for d in metadatas: - for key, value in d.items(): - if key in self.fields: - insert_dict.setdefault(key, []).append(value) + # for key, value in d.items(): + insert_dict.setdefault("metadata", []).append(json.dumps(d)) # Convert dict to list of lists for insertion insert_list = [insert_dict[x] for x in self.fields] # Insert into the collection. @@ -261,7 +267,7 @@ class MilvusStore(VectorStoreBase): doc_ids = [str(doc_id) for doc_id in doc_ids] return doc_ids - def similar_search(self, text, topk) -> None: + def similar_search(self, text, topk): from pymilvus import Collection, DataType """similar_search in vector database.""" @@ -276,7 +282,16 @@ class MilvusStore(VectorStoreBase): if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: self.vector_field = x.name _, docs_and_scores = self._search(text, topk) - return [doc for doc, _, _ in docs_and_scores] + from langchain.schema import Document + + return [ + Document( + metadata=json.loads(doc.metadata.get("metadata", "")), + page_content=doc.page_content, + ) + for doc, _, _ in docs_and_scores + ] + # return [doc for doc, _, _ in docs_and_scores] def _search( self,