diff --git a/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py b/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py index 0b72190876f..01db46c99f9 100644 --- a/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py +++ b/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py @@ -240,6 +240,7 @@ class Milvus(VectorStore): replica_number: int = 1, timeout: Optional[float] = None, num_shards: Optional[int] = None, + metadata_schema: Optional[dict[str, Any]] = None, ): """Initialize the Milvus vector store.""" try: @@ -310,6 +311,7 @@ class Milvus(VectorStore): self.replica_number = replica_number self.timeout = timeout self.num_shards = num_shards + self.metadata_schema = metadata_schema # Create the connection to the server if connection_args is None: @@ -472,24 +474,47 @@ class Milvus(VectorStore): ) raise ValueError(f"Metadata key {key} is reserved.") # Infer the corresponding datatype of the metadata - dtype = infer_dtype_bydata(value) - # Datatype isn't compatible - if dtype == DataType.UNKNOWN or dtype == DataType.NONE: - logger.error( - ( - "Failure to create collection, " - "unrecognized dtype for key: %s" - ), - key, - ) - raise ValueError(f"Unrecognized datatype for {key}.") - # Datatype is a string/varchar equivalent - elif dtype == DataType.VARCHAR: + if ( + key in self.metadata_schema # type: ignore + and "dtype" in self.metadata_schema[key] # type: ignore + ): + kwargs = self.metadata_schema[key].get("kwargs", {}) # type: ignore fields.append( - FieldSchema(key, DataType.VARCHAR, max_length=65_535) + FieldSchema( + name=key, + dtype=self.metadata_schema[key]["dtype"], # type: ignore + **kwargs, + ) ) else: - fields.append(FieldSchema(key, dtype)) + dtype = infer_dtype_bydata(value) + # Datatype isn't compatible + if dtype == DataType.UNKNOWN or dtype == DataType.NONE: + logger.error( + ( + "Failure to create collection, " + "unrecognized dtype for key: %s" + ), + key, + ) + raise ValueError(f"Unrecognized datatype for {key}.") + # Datatype is a string/varchar equivalent + elif dtype == DataType.VARCHAR: + fields.append( + FieldSchema(key, DataType.VARCHAR, max_length=65_535) + ) + # infer_dtype_bydata currently can't recognize array type, + # so this line can not be accessed. + # This line may need to be modified in the future when + # infer_dtype_bydata can recognize array type. + # https://github.com/milvus-io/pymilvus/issues/2165 + elif dtype == DataType.ARRAY: + kwargs = self.metadata_schema[key]["kwargs"] # type: ignore + fields.append( + FieldSchema(name=key, dtype=DataType.ARRAY, **kwargs) + ) + else: + fields.append(FieldSchema(key, dtype)) # Create the text field fields.append( diff --git a/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py b/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py index 7e9b8f76f9a..2066f61c56d 100644 --- a/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py +++ b/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py @@ -39,6 +39,7 @@ def _milvus_from_texts( # connection_args={"uri": "http://127.0.0.1:19530"}, connection_args={"uri": "./milvus_demo.db"}, drop_old=drop, + consistency_level="Strong", **kwargs, ) @@ -303,6 +304,51 @@ def test_milvus_enable_dynamic_field_with_partition_key() -> None: } +def test_milvus_array_field() -> None: + """Manually specify metadata schema, including an array_field. + For more information about array data type and filtering, please refer to + https://milvus.io/docs/array_data_type.md + """ + from pymilvus import DataType + + texts = ["foo", "bar", "baz"] + metadatas = [{"id": i, "array_field": [i, i + 1, i + 2]} for i in range(len(texts))] + + # Manually specify metadata schema, including an array_field. + # If some fields are not specified, Milvus will automatically infer their schemas. + docsearch = _milvus_from_texts( + metadatas=metadatas, + metadata_schema={ + "array_field": { + "dtype": DataType.ARRAY, + "kwargs": {"element_type": DataType.INT64, "max_capacity": 50}, + }, + # "id": { + # "dtype": DataType.INT64, + # } + }, + ) + output = docsearch.similarity_search("foo", k=10, expr="array_field[0] < 2") + assert len(output) == 2 + output = docsearch.similarity_search( + "foo", k=10, expr="ARRAY_CONTAINS(array_field, 3)" + ) + assert len(output) == 2 + + # If we use enable_dynamic_field, + # there is no need to manually specify metadata schema. + docsearch = _milvus_from_texts( + enable_dynamic_field=True, + metadatas=metadatas, + ) + output = docsearch.similarity_search("foo", k=10, expr="array_field[0] < 2") + assert len(output) == 2 + output = docsearch.similarity_search( + "foo", k=10, expr="ARRAY_CONTAINS(array_field, 3)" + ) + assert len(output) == 2 + + # if __name__ == "__main__": # test_milvus() # test_milvus_vector_search() @@ -319,3 +365,4 @@ def test_milvus_enable_dynamic_field_with_partition_key() -> None: # test_milvus_enable_dynamic_field() # test_milvus_disable_dynamic_field() # test_milvus_metadata_field() +# test_milvus_array_field()