diff --git a/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py b/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py index 343a92fa68c..c33af2bfcd7 100644 --- a/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py +++ b/libs/partners/milvus/langchain_milvus/vectorstores/milvus.py @@ -130,24 +130,46 @@ class Milvus(VectorStore): drop_old (Optional[bool]): Whether to drop the current collection. Defaults to False. auto_id (bool): Whether to enable auto id for primary key. Defaults to False. - If False, you needs to provide text ids (string less than 65535 bytes). + If False, you need to provide text ids (string less than 65535 bytes). If True, Milvus will generate unique integers as primary keys. primary_field (str): Name of the primary key field. Defaults to "pk". text_field (str): Name of the text field. Defaults to "text". vector_field (str): Name of the vector field. Defaults to "vector". - metadata_field (str): Name of the metadta field. Defaults to None. + enable_dynamic_field (Optional[bool]): Whether to enable + dynamic schema or not in Milvus. Defaults to False. + For more information about dynamic schema, please refer to + https://milvus.io/docs/enable-dynamic-field.md + metadata_field (str): Name of the metadata field. Defaults to None. When metadata_field is specified, the document's metadata will store as json. + This argument is about to be deprecated, + because it can be replaced by setting `enable_dynamic_field`=True. + partition_key_field (Optional[str]): Name of the partition key field. + Defaults to None. For more information about partition key, please refer to + https://milvus.io/docs/use-partition-key.md#Use-Partition-Key + partition_names (Optional[list]): List of specific partition names. + Defaults to None. For more information about partition, please refer to + https://milvus.io/docs/manage-partitions.md#Manage-Partitions + replica_number (int): The number of replicas for the collection. Defaults to 1. 
+ For more information about replica, please refer to + https://milvus.io/docs/replica.md#In-Memory-Replica + timeout (Optional[float]): The timeout for Milvus operations. Defaults to None. + An optional duration of time in seconds to allow for the RPCs. + If timeout is not set, the client keeps waiting until the server responds + or an error occurs. + num_shards (Optional[int]): The number of shards for the collection. + Defaults to None. For more information about shards, please refer to + https://milvus.io/docs/glossary.md#Shard The connection args used for this class comes in the form of a dict, here are a few of the options: address (str): The actual address of Milvus instance. Example address: "localhost:19530" uri (str): The uri of Milvus instance. Example uri: + "path/to/local/directory/milvus_demo.db" for Milvus Lite. "http://randomwebsite:19530", "tcp:foobarsite:19530", "https://ok.s3.south.com:19530". - or "path/to/local/directory/milvus_demo.db" for Milvus Lite. host (str): The host of Milvus instance. Default at "localhost", PyMilvus will fill in the default host if only port is provided. port (str/int): The port of Milvus instance. 
Default at 19530, PyMilvus @@ -178,6 +200,7 @@ class Milvus(VectorStore): milvus_store = Milvus( embedding_function = Embeddings, collection_name = "LangChainCollection", + connection_args = {"uri": "./milvus_demo.db"}, drop_old = True, auto_id = True ) @@ -202,6 +225,7 @@ class Milvus(VectorStore): primary_field: str = "pk", text_field: str = "text", vector_field: str = "vector", + enable_dynamic_field: bool = False, metadata_field: Optional[str] = None, partition_key_field: Optional[str] = None, partition_names: Optional[list] = None, @@ -260,6 +284,17 @@ class Milvus(VectorStore): self._text_field = text_field # In order for compatibility, the vector field needs to be called "vector" self._vector_field = vector_field + if metadata_field: + logger.warning( + "DeprecationWarning: `metadata_field` is about to be deprecated, " + "please set `enable_dynamic_field`=True instead." + ) + if enable_dynamic_field and metadata_field: + metadata_field = None + logger.warning( + "When `enable_dynamic_field` is True, `metadata_field` is ignored." + ) + self.enable_dynamic_field = enable_dynamic_field self._metadata_field = metadata_field self._partition_key_field = partition_key_field self.fields: list[str] = [] @@ -389,13 +424,36 @@ class Milvus(VectorStore): # Determine embedding dim dim = len(embeddings[0]) fields = [] - if self._metadata_field is not None: + # If enable_dynamic_field, we don't need to create fields, and just pass it. + # In the future, when metadata_field is deprecated, + # This logical structure will be simplified like this: + # ``` + # if not self.enable_dynamic_field and metadatas: + # for key, value in metadatas[0].items(): + # ... + # ``` + if self.enable_dynamic_field: + pass + elif self._metadata_field is not None: fields.append(FieldSchema(self._metadata_field, DataType.JSON)) else: # Determine metadata schema if metadatas: # Create FieldSchema for each entry in metadata. 
for key, value in metadatas[0].items(): + if key in [ + self._vector_field, + self._primary_field, + self._text_field, + ]: + logger.error( + ( + "Failure to create collection, " + "metadata key: %s is reserved." + ), + key, + ) + raise ValueError(f"Metadata key {key} is reserved.") # Infer the corresponding datatype of the metadata dtype = infer_dtype_bydata(value) # Datatype isn't compatible @@ -408,7 +466,7 @@ class Milvus(VectorStore): key, ) raise ValueError(f"Unrecognized datatype for {key}.") - # Dataype is a string/varchar equivalent + # Datatype is a string/varchar equivalent elif dtype == DataType.VARCHAR: fields.append( FieldSchema(key, DataType.VARCHAR, max_length=65_535) @@ -447,6 +505,7 @@ class Milvus(VectorStore): fields, description=self.collection_description, partition_key_field=self._partition_key_field, + enable_dynamic_field=self.enable_dynamic_field, ) # Create the collection @@ -617,16 +676,26 @@ class Milvus(VectorStore): texts = list(texts) if not self.auto_id: - assert isinstance( - ids, list - ), "A list of valid ids are required when auto_id is False." + assert isinstance(ids, list), ( + "A list of valid ids are required when auto_id is False. " + "You can set `auto_id` to True in this Milvus instance to generate " + "ids automatically, or specify string-type ids for each text." + ) assert len(set(ids)) == len( texts ), "Different lengths of texts and unique ids are provided." + assert all(isinstance(x, str) for x in ids), "All ids should be strings." assert all( len(x.encode()) <= 65_535 for x in ids ), "Each id should be a string less than 65535 bytes." + else: + if ids is not None: + logger.warning( + "The ids parameter is ignored when auto_id is True. " + "The ids will be generated automatically." 
+ ) + try: embeddings = self.embedding_func.embed_documents(texts) except NotImplementedError: @@ -647,34 +716,39 @@ class Milvus(VectorStore): kwargs["timeout"] = self.timeout self._init(**kwargs) - # Dict to hold all insert columns - insert_dict: dict[str, list] = { - self._text_field: texts, - self._vector_field: embeddings, - } + insert_list: list[dict] = [] - if not self.auto_id: - insert_dict[self._primary_field] = ids # type: ignore[assignment] + assert len(texts) == len( + embeddings + ), "Mismatched lengths of texts and embeddings." + if metadatas is not None: + assert len(texts) == len( + metadatas + ), "Mismatched lengths of texts and metadatas." - if self._metadata_field is not None: - for d in metadatas: # type: ignore[union-attr] - insert_dict.setdefault(self._metadata_field, []).append(d) - else: - # Collect the metadata into the insert dict. - if metadatas is not None: - for d in metadatas: - for key, value in d.items(): - keys = ( - [x for x in self.fields if x != self._primary_field] - if self.auto_id - else [x for x in self.fields] - ) - if key in keys: - insert_dict.setdefault(key, []).append(value) + for i, text, embedding in zip(range(len(texts)), texts, embeddings): + entity_dict = {} + metadata = metadatas[i] if metadatas else {} + if not self.auto_id: + entity_dict[self._primary_field] = ids[i] # type: ignore[index] + + entity_dict[self._text_field] = text + entity_dict[self._vector_field] = embedding + + if self._metadata_field and not self.enable_dynamic_field: + entity_dict[self._metadata_field] = metadata + else: + for key, value in metadata.items(): + # if not enable_dynamic_field, skip fields not in the collection. + if not self.enable_dynamic_field and key not in self.fields: + continue + # If enable_dynamic_field, all fields are allowed. 
+ entity_dict[key] = value + + insert_list.append(entity_dict) # Total insert count - vectors: list = insert_dict[self._vector_field] - total_count = len(vectors) + total_count = len(insert_list) pks: list[str] = [] @@ -682,15 +756,12 @@ class Milvus(VectorStore): for i in range(0, total_count, batch_size): # Grab end index end = min(i + batch_size, total_count) - # Convert dict to list of lists batch for insertion - insert_list = [ - insert_dict[x][i:end] for x in self.fields if x in insert_dict - ] + batch_insert_list = insert_list[i:end] # Insert into the collection. try: res: Collection timeout = self.timeout or timeout - res = self.col.insert(insert_list, timeout=timeout, **kwargs) + res = self.col.insert(batch_insert_list, timeout=timeout, **kwargs) pks.extend(res.primary_keys) except MilvusException as e: logger.error( @@ -699,6 +770,61 @@ class Milvus(VectorStore): raise e return pks + def _collection_search( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[float] = None, + **kwargs: Any, + ) -> "pymilvus.client.abstract.SearchResult | None": # type: ignore[name-defined] # noqa: F821 + """Perform a search on an embedding and return milvus search results. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md + + Args: + embedding (List[float]): The embedding vector being searched. + k (int, optional): The amount of results to return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (float, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + pymilvus.client.abstract.SearchResult: Milvus search result. 
+ """ + if self.col is None: + logger.debug("No existing collection to search.") + return None + + if param is None: + param = self.search_params + + # Determine result metadata fields with PK. + if self.enable_dynamic_field: + output_fields = ["*"] + else: + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + timeout = self.timeout or timeout + # Perform the search. + res = self.col.search( + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=k, + expr=expr, + output_fields=output_fields, + timeout=timeout, + **kwargs, + ) + return res + def similarity_search( self, query: str, @@ -778,7 +904,7 @@ class Milvus(VectorStore): For more information about the search parameters, take a look at the pymilvus documentation found here: - https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md Args: query (str): The text being searched. @@ -814,11 +940,11 @@ class Milvus(VectorStore): timeout: Optional[float] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: - """Perform a search on a query string and return results with score. + """Perform a search on an embedding and return results with score. For more information about the search parameters, take a look at the pymilvus documentation found here: - https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md Args: embedding (List[float]): The embedding vector being searched. @@ -833,32 +959,14 @@ class Milvus(VectorStore): Returns: List[Tuple[Document, float]]: Result doc and score. """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - if param is None: - param = self.search_params - - # Determine result metadata fields with PK. 
- output_fields = self.fields[:] - output_fields.remove(self._vector_field) - timeout = self.timeout or timeout - # Perform the search. - res = self.col.search( - data=[embedding], - anns_field=self._vector_field, - param=param, - limit=k, - expr=expr, - output_fields=output_fields, - timeout=timeout, - **kwargs, + col_search_res = self._collection_search( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs ) - # Organize results. + if col_search_res is None: + return [] ret = [] - for result in res[0]: - data = {x: result.entity.get(x) for x in output_fields} + for result in col_search_res[0]: + data = {x: result.entity.get(x) for x in result.entity.fields} doc = self._parse_document(data) pair = (doc, result.score) ret.append(pair) @@ -947,40 +1055,27 @@ class Milvus(VectorStore): Returns: List[Document]: Document results for search. """ - if self.col is None: - logger.debug("No existing collection to search.") - return [] - - if param is None: - param = self.search_params - - # Determine result metadata fields. - output_fields = self.fields[:] - output_fields.remove(self._vector_field) - timeout = self.timeout or timeout - # Perform the search. - res = self.col.search( - data=[embedding], - anns_field=self._vector_field, + col_search_res = self._collection_search( + embedding=embedding, + k=fetch_k, param=param, - limit=fetch_k, expr=expr, - output_fields=output_fields, timeout=timeout, **kwargs, ) - # Organize results. 
+ if col_search_res is None: + return [] ids = [] documents = [] scores = [] - for result in res[0]: - data = {x: result.entity.get(x) for x in output_fields} + for result in col_search_res[0]: + data = {x: result.entity.get(x) for x in result.entity.fields} doc = self._parse_document(data) documents.append(doc) scores.append(result.score) ids.append(result.id) - vectors = self.col.query( + vectors = self.col.query( # type: ignore[union-attr] expr=f"{self._primary_field} in {ids}", output_fields=[self._primary_field, self._vector_field], timeout=timeout, @@ -1089,6 +1184,8 @@ class Milvus(VectorStore): return vector_db def _parse_document(self, data: dict) -> Document: + if self._vector_field in data: + data.pop(self._vector_field) return Document( page_content=data.pop(self._text_field), metadata=data.pop(self._metadata_field) if self._metadata_field else data, diff --git a/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py b/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py index 5eaf2abcc90..a565cac81d4 100644 --- a/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py +++ b/libs/partners/milvus/tests/integration_tests/vectorstores/test_milvus.py @@ -1,6 +1,7 @@ """Test Milvus functionality.""" from typing import Any, List, Optional +import pytest from langchain_core.documents import Document from langchain_milvus.vectorstores import Milvus @@ -27,6 +28,7 @@ def _milvus_from_texts( metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, drop: bool = True, + **kwargs: Any, ) -> Milvus: return Milvus.from_texts( fake_texts, @@ -36,6 +38,7 @@ def _milvus_from_texts( # connection_args={"uri": "http://127.0.0.1:19530"}, connection_args={"uri": "./milvus_demo.db"}, drop_old=drop, + **kwargs, ) @@ -50,6 +53,15 @@ def test_milvus() -> None: assert_docs_equal_without_pk(output, [Document(page_content="foo")]) +def test_milvus_vector_search() -> None: + """Test end to end construction and 
search by vector.""" + docsearch = _milvus_from_texts() + output = docsearch.similarity_search_by_vector( + FakeEmbeddings().embed_query("foo"), k=1 + ) + assert_docs_equal_without_pk(output, [Document(page_content="foo")]) + + def test_milvus_with_metadata() -> None: """Test with metadata""" docsearch = _milvus_from_texts(metadatas=[{"label": "test"}] * len(fake_texts)) @@ -110,6 +122,21 @@ def test_milvus_max_marginal_relevance_search() -> None: ) +def test_milvus_max_marginal_relevance_search_with_dynamic_field() -> None: + """Test end to end construction and MRR search with enabling dynamic field.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _milvus_from_texts(metadatas=metadatas, enable_dynamic_field=True) + output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) + assert_docs_equal_without_pk( + output, + [ + Document(page_content="foo", metadata={"page": 0}), + Document(page_content="baz", metadata={"page": 2}), + ], + ) + + def test_milvus_add_extra() -> None: """Test end to end construction and MRR search.""" texts = ["foo", "bar", "baz"] @@ -123,7 +150,7 @@ def test_milvus_add_extra() -> None: def test_milvus_no_drop() -> None: - """Test end to end construction and MRR search.""" + """Test construction without dropping old data.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] docsearch = _milvus_from_texts(metadatas=metadatas) @@ -171,14 +198,95 @@ def test_milvus_upsert_entities() -> None: assert len(ids) == 2 # type: ignore[arg-type] +def test_milvus_enable_dynamic_field() -> None: + """Test end to end construction and enable dynamic field""" + texts = ["foo", "bar", "baz"] + metadatas = [{"id": i} for i in range(len(texts))] + docsearch = _milvus_from_texts(metadatas=metadatas, enable_dynamic_field=True) + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 3 + + # When enable dynamic field, any new field data will be 
added to the collection. + new_metadatas = [{"id_new": i} for i in range(len(texts))] + docsearch.add_texts(texts, new_metadatas) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + assert set(docsearch.fields) == { + docsearch._primary_field, + docsearch._text_field, + docsearch._vector_field, + } + + +def test_milvus_disable_dynamic_field() -> None: + """Test end to end construction and disable dynamic field""" + texts = ["foo", "bar", "baz"] + metadatas = [{"id": i} for i in range(len(texts))] + docsearch = _milvus_from_texts(metadatas=metadatas, enable_dynamic_field=False) + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 3 + # ["pk", "text", "vector", "id"] + assert set(docsearch.fields) == { + docsearch._primary_field, + docsearch._text_field, + docsearch._vector_field, + "id", + } + + # Try to add a new field "id_new". Since dynamic field is disabled and + # the collection's fields are fixed as ["pk", "text", "vector", "id"], + # the new field "id_new" will not be added. + new_metadatas = [{"id": i, "id_new": i} for i in range(len(texts))] + docsearch.add_texts(texts, new_metadatas) + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + for doc in output: + assert set(doc.metadata.keys()) == {"id", "pk"} # `id_new` is not added. + + # When dynamic field is disabled, + # omitting data for the existing field "id" will raise an exception. 
+ with pytest.raises(Exception): + new_metadatas = [{"id_new": i} for i in range(len(texts))] + docsearch.add_texts(texts, new_metadatas) + + +def test_milvus_metadata_field() -> None: + """Test end to end construction and use metadata field""" + texts = ["foo", "bar", "baz"] + metadatas = [{"id": i} for i in range(len(texts))] + docsearch = _milvus_from_texts(metadatas=metadatas, metadata_field="metadata") + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 3 + + new_metadatas = [{"id_new": i} for i in range(len(texts))] + docsearch.add_texts(texts, new_metadatas) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + assert set(docsearch.fields) == { + docsearch._primary_field, + docsearch._text_field, + docsearch._vector_field, + docsearch._metadata_field, + } + + # if __name__ == "__main__": # test_milvus() +# test_milvus_vector_search() # test_milvus_with_metadata() # test_milvus_with_id() # test_milvus_with_score() # test_milvus_max_marginal_relevance_search() +# test_milvus_max_marginal_relevance_search_with_dynamic_field() # test_milvus_add_extra() # test_milvus_no_drop() # test_milvus_get_pks() # test_milvus_delete_entities() # test_milvus_upsert_entities() +# test_milvus_enable_dynamic_field() +# test_milvus_disable_dynamic_field() +# test_milvus_metadata_field()