From 780ce00deac87340673f338b51c80e5c1558e97a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Carlos=20Ferra=20de=20Almeida?=
Date: Mon, 7 Oct 2024 19:52:50 +0100
Subject: [PATCH] core[minor]: add **kwargs to index and aindex functions for custom vector_field support (#26998)

Added an `upsert_kwargs` parameter to the `index` and `aindex` functions in
`libs/core/langchain_core/indexing/api.py`. This allows users to pass
additional keyword arguments through to the `add_documents` and
`aadd_documents` methods of the `VectorStore` (or the `upsert` and `aupsert`
methods of a `DocumentIndex`), enabling the specification of a custom
`vector_field`. For example, users can now pass
`upsert_kwargs={"vector_field": "embedding"}` when indexing documents in
`OpenSearchVectorStore` (see the usage sketch after the patch).

---------

Co-authored-by: Eugene Yurtsev
---
 libs/core/langchain_core/indexing/api.py        |  34 +++-
 .../unit_tests/indexing/test_indexing.py        | 178 ++++++++++++++++++
 2 files changed, 208 insertions(+), 4 deletions(-)

diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py
index 814356b17c3..26566b1be80 100644
--- a/libs/core/langchain_core/indexing/api.py
+++ b/libs/core/langchain_core/indexing/api.py
@@ -198,6 +198,7 @@ def index(
     source_id_key: Union[str, Callable[[Document], str], None] = None,
     cleanup_batch_size: int = 1_000,
     force_update: bool = False,
+    upsert_kwargs: Optional[dict[str, Any]] = None,
 ) -> IndexingResult:
     """Index data from the loader into the vector store.
 
@@ -249,6 +250,12 @@ def index(
         force_update: Force update documents even if they are present in the
             record manager. Useful if you are re-indexing with updated embeddings.
             Default is False.
+        upsert_kwargs: Additional keyword arguments to pass to the add_documents
+            method of the VectorStore or the upsert method of the
+            DocumentIndex. For example, you can use this to
+            specify a custom vector_field:
+            upsert_kwargs={"vector_field": "embedding"}
+            .. versionadded:: 0.3.10
 
     Returns:
         Indexing result which contains information about how many documents
@@ -363,10 +370,16 @@ def index(
         if docs_to_index:
             if isinstance(destination, VectorStore):
                 destination.add_documents(
-                    docs_to_index, ids=uids, batch_size=batch_size
+                    docs_to_index,
+                    ids=uids,
+                    batch_size=batch_size,
+                    **(upsert_kwargs or {}),
                 )
             elif isinstance(destination, DocumentIndex):
-                destination.upsert(docs_to_index)
+                destination.upsert(
+                    docs_to_index,
+                    **(upsert_kwargs or {}),
+                )
 
             num_added += len(docs_to_index) - len(seen_docs)
             num_updated += len(seen_docs)
@@ -438,6 +451,7 @@ async def aindex(
     source_id_key: Union[str, Callable[[Document], str], None] = None,
     cleanup_batch_size: int = 1_000,
     force_update: bool = False,
+    upsert_kwargs: Optional[dict[str, Any]] = None,
 ) -> IndexingResult:
     """Async index data from the loader into the vector store.
 
@@ -480,6 +494,12 @@ async def aindex(
         force_update: Force update documents even if they are present in the
             record manager. Useful if you are re-indexing with updated embeddings.
             Default is False.
+        upsert_kwargs: Additional keyword arguments to pass to the aadd_documents
+            method of the VectorStore or the aupsert method of the
+            DocumentIndex. For example, you can use this to
+            specify a custom vector_field:
+            upsert_kwargs={"vector_field": "embedding"}
+            .. versionadded:: 0.3.10
 
     Returns:
         Indexing result which contains information about how many documents
@@ -604,10 +624,16 @@ async def aindex(
         if docs_to_index:
             if isinstance(destination, VectorStore):
                 await destination.aadd_documents(
-                    docs_to_index, ids=uids, batch_size=batch_size
+                    docs_to_index,
+                    ids=uids,
+                    batch_size=batch_size,
+                    **(upsert_kwargs or {}),
                 )
             elif isinstance(destination, DocumentIndex):
-                await destination.aupsert(docs_to_index)
+                await destination.aupsert(
+                    docs_to_index,
+                    **(upsert_kwargs or {}),
+                )
 
             num_added += len(docs_to_index) - len(seen_docs)
             num_updated += len(seen_docs)

diff --git a/libs/core/tests/unit_tests/indexing/test_indexing.py b/libs/core/tests/unit_tests/indexing/test_indexing.py
index 96d3584dad8..287b6b49f66 100644
--- a/libs/core/tests/unit_tests/indexing/test_indexing.py
+++ b/libs/core/tests/unit_tests/indexing/test_indexing.py
@@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 import pytest_asyncio
+from pytest_mock import MockerFixture
 
 from langchain_core.document_loaders.base import BaseLoader
 from langchain_core.documents import Document
@@ -1728,3 +1729,180 @@ async def test_incremental_aindexing_with_batch_size_with_optimization(
         for uid in vector_store.store
     }
     assert doc_texts == {"updated 1", "2", "3", "updated 4"}
+
+
+def test_index_with_upsert_kwargs(
+    record_manager: InMemoryRecordManager, upserting_vector_store: InMemoryVectorStore
+) -> None:
+    """Test indexing with upsert_kwargs parameter."""
+    mock_add_documents = MagicMock()
+
+    with patch.object(upserting_vector_store, "add_documents", mock_add_documents):
+        docs = [
+            Document(
+                page_content="Test document 1",
+                metadata={"source": "1"},
+            ),
+            Document(
+                page_content="Test document 2",
+                metadata={"source": "2"},
+            ),
+        ]
+
+        upsert_kwargs = {"vector_field": "embedding"}
+
+        index(docs, record_manager, upserting_vector_store, upsert_kwargs=upsert_kwargs)
+
+    # Assert that add_documents was called with the correct arguments
+    mock_add_documents.assert_called_once()
+    call_args = mock_add_documents.call_args
+    assert call_args is not None
+    args, kwargs = call_args
+
+    # Check that the documents are correct (ignoring ids)
+    assert len(args[0]) == 2
+    assert all(isinstance(doc, Document) for doc in args[0])
+    assert [doc.page_content for doc in args[0]] == [
+        "Test document 1",
+        "Test document 2",
+    ]
+    assert [doc.metadata for doc in args[0]] == [{"source": "1"}, {"source": "2"}]
+
+    # Check that ids are present
+    assert "ids" in kwargs
+    assert isinstance(kwargs["ids"], list)
+    assert len(kwargs["ids"]) == 2
+
+    # Check other arguments
+    assert kwargs["batch_size"] == 100
+    assert kwargs["vector_field"] == "embedding"
+
+
+def test_index_with_upsert_kwargs_for_document_indexer(
+    record_manager: InMemoryRecordManager,
+    mocker: MockerFixture,
+) -> None:
+    """Test that kwargs are passed to the upsert method of the document indexer."""
+
+    document_index = InMemoryDocumentIndex()
+    upsert_spy = mocker.spy(document_index.__class__, "upsert")
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": "1"},
+        ),
+        Document(
+            page_content="This is another document.",
+            metadata={"source": "2"},
+        ),
+    ]
+
+    upsert_kwargs = {"vector_field": "embedding"}
+
+    assert index(
+        docs,
+        record_manager,
+        document_index,
+        cleanup="full",
+        upsert_kwargs=upsert_kwargs,
+    ) == {
+        "num_added": 2,
+        "num_deleted": 0,
+        "num_skipped": 0,
+        "num_updated": 0,
+    }
+
+    assert upsert_spy.call_count == 1
+    # assert call kwargs were passed as kwargs
+    assert upsert_spy.call_args.kwargs == upsert_kwargs
+
+
+async def test_aindex_with_upsert_kwargs_for_document_indexer(
+    arecord_manager: InMemoryRecordManager,
+    mocker: MockerFixture,
+) -> None:
+    """Test that kwargs are passed to the upsert method of the document indexer."""
+
+    document_index = InMemoryDocumentIndex()
+    upsert_spy = mocker.spy(document_index.__class__, "aupsert")
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": "1"},
+        ),
+        Document(
+            page_content="This is another document.",
+            metadata={"source": "2"},
+        ),
+    ]
+
+    upsert_kwargs = {"vector_field": "embedding"}
+
+    assert await aindex(
+        docs,
+        arecord_manager,
+        document_index,
+        cleanup="full",
+        upsert_kwargs=upsert_kwargs,
+    ) == {
+        "num_added": 2,
+        "num_deleted": 0,
+        "num_skipped": 0,
+        "num_updated": 0,
+    }
+
+    assert upsert_spy.call_count == 1
+    # assert call kwargs were passed as kwargs
+    assert upsert_spy.call_args.kwargs == upsert_kwargs
+
+
+async def test_aindex_with_upsert_kwargs(
+    arecord_manager: InMemoryRecordManager, upserting_vector_store: InMemoryVectorStore
+) -> None:
+    """Test async indexing with upsert_kwargs parameter."""
+    mock_aadd_documents = AsyncMock()
+
+    with patch.object(upserting_vector_store, "aadd_documents", mock_aadd_documents):
+        docs = [
+            Document(
+                page_content="Async test document 1",
+                metadata={"source": "1"},
+            ),
+            Document(
+                page_content="Async test document 2",
+                metadata={"source": "2"},
+            ),
+        ]
+
+        upsert_kwargs = {"vector_field": "embedding"}
+
+        await aindex(
+            docs,
+            arecord_manager,
+            upserting_vector_store,
+            upsert_kwargs=upsert_kwargs,
+        )
+
+    # Assert that aadd_documents was called with the correct arguments
+    mock_aadd_documents.assert_called_once()
+    call_args = mock_aadd_documents.call_args
+    assert call_args is not None
+    args, kwargs = call_args
+
+    # Check that the documents are correct (ignoring ids)
+    assert len(args[0]) == 2
+    assert all(isinstance(doc, Document) for doc in args[0])
+    assert [doc.page_content for doc in args[0]] == [
+        "Async test document 1",
+        "Async test document 2",
+    ]
+    assert [doc.metadata for doc in args[0]] == [{"source": "1"}, {"source": "2"}]
+
+    # Check that ids are present
+    assert "ids" in kwargs
+    assert isinstance(kwargs["ids"], list)
+    assert len(kwargs["ids"]) == 2
+
+    # Check other arguments
+    assert kwargs["batch_size"] == 100
+    assert kwargs["vector_field"] == "embedding"
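
Usage sketch (not part of the patch): how a caller might exercise the new
upsert_kwargs parameter end to end. This is a minimal sketch, assuming the
OpenSearch integration referenced in the commit message; the
OpenSearchVectorSearch import path, its constructor arguments, the placeholder
URL and index name, and the "vector_field" keyword are assumptions and may
differ in your installed versions.

# Usage sketch (illustrative only). Assumes langchain_community's OpenSearch
# integration; class and argument names may differ across versions.
from langchain_community.vectorstores import OpenSearchVectorSearch  # assumed import path
from langchain_core.documents import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.indexing import InMemoryRecordManager, index

# The record manager tracks what has already been indexed so re-runs are idempotent.
record_manager = InMemoryRecordManager(namespace="opensearch/my-index")
record_manager.create_schema()

# An OpenSearch-backed store whose index mapping keeps vectors under "embedding".
vector_store = OpenSearchVectorSearch(
    opensearch_url="http://localhost:9200",  # placeholder URL
    index_name="my-index",  # placeholder index name
    embedding_function=DeterministicFakeEmbedding(size=8),  # stand-in embeddings
)

docs = [
    Document(page_content="Test document 1", metadata={"source": "1"}),
    Document(page_content="Test document 2", metadata={"source": "2"}),
]

# upsert_kwargs is forwarded verbatim to add_documents (or upsert for a
# DocumentIndex), so the custom vector field name reaches the store.
result = index(
    docs,
    record_manager,
    vector_store,
    cleanup="incremental",
    source_id_key="source",
    upsert_kwargs={"vector_field": "embedding"},
)
print(result)  # e.g. {'num_added': 2, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}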