From e1a6341940e7b62a5a1ecff17f99d56f362d449b Mon Sep 17 00:00:00 2001
From: Zachary Wilkins
Date: Mon, 25 Mar 2024 11:58:29 -0400
Subject: [PATCH] langchain: Passthrough batch_size on index()/aindex() calls
 (#19443)

**Description:** This change passes through `batch_size` to
`add_documents()`/`aadd_documents()` on calls to `index()` and `aindex()`
such that the documents are processed in the expected batch size.

**Issue:** #19415
**Dependencies:** N/A
**Twitter handle:** N/A
---
 libs/langchain/langchain/indexes/_api.py           |  6 ++-
 .../tests/unit_tests/indexes/test_indexing.py      | 43 ++++++++++++++++++-
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py
index 01371c4b73f..f41a221795a 100644
--- a/libs/langchain/langchain/indexes/_api.py
+++ b/libs/langchain/langchain/indexes/_api.py
@@ -330,7 +330,7 @@ def index(
         # Be pessimistic and assume that all vector store write will fail.
         # First write to vector store
         if docs_to_index:
-            vector_store.add_documents(docs_to_index, ids=uids)
+            vector_store.add_documents(docs_to_index, ids=uids, batch_size=batch_size)
             num_added += len(docs_to_index) - len(seen_docs)
             num_updated += len(seen_docs)
 
@@ -544,7 +544,9 @@ async def aindex(
         # Be pessimistic and assume that all vector store write will fail.
         # First write to vector store
         if docs_to_index:
-            await vector_store.aadd_documents(docs_to_index, ids=uids)
+            await vector_store.aadd_documents(
+                docs_to_index, ids=uids, batch_size=batch_size
+            )
             num_added += len(docs_to_index) - len(seen_docs)
             num_updated += len(seen_docs)
 
diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
index 73c906a1850..10275db9439 100644
--- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py
+++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
@@ -20,7 +20,7 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VST, VectorStore
 
 from langchain.indexes import aindex, index
-from langchain.indexes._api import _abatch
+from langchain.indexes._api import _abatch, _HashedDocument
 from langchain.indexes._sql_record_manager import SQLRecordManager
 
@@ -1304,3 +1304,44 @@ async def test_aindexing_force_update(
         "num_skipped": 0,
         "num_updated": 2,
     }
+
+
+def test_indexing_custom_batch_size(
+    record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
+) -> None:
+    """Test indexing with a custom batch size."""
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": "1"},
+        ),
+    ]
+    ids = [_HashedDocument.from_document(doc).uid for doc in docs]
+
+    batch_size = 1
+    with patch.object(vector_store, "add_documents") as mock_add_documents:
+        index(docs, record_manager, vector_store, batch_size=batch_size)
+        args, kwargs = mock_add_documents.call_args
+        assert args == (docs,)
+        assert kwargs == {"ids": ids, "batch_size": batch_size}
+
+
+@pytest.mark.requires("aiosqlite")
+async def test_aindexing_custom_batch_size(
+    arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
+) -> None:
+    """Test indexing with a custom batch size."""
+    docs = [
+        Document(
+            page_content="This is a test document.",
+            metadata={"source": "1"},
+        ),
+    ]
+    ids = [_HashedDocument.from_document(doc).uid for doc in docs]
+
+    batch_size = 1
+    with patch.object(vector_store, "aadd_documents") as mock_add_documents:
+        await aindex(docs, arecord_manager, vector_store, batch_size=batch_size)
+        args, kwargs = mock_add_documents.call_args
+        assert args == (docs,)
+        assert kwargs == {"ids": ids, "batch_size": batch_size}
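
Usage sketch (not part of the patch): the snippet below shows `batch_size` being forwarded from `index()` to a store. `ToyVectorStore`, the `toy/demo` namespace, and the SQLite URL are made-up stand-ins; the only assumption is that the destination store's `add_documents()` accepts a `batch_size` keyword (stores without such a keyword or a `**kwargs` catch-all may reject the forwarded argument).

from typing import Any, Iterable, List, Optional, Sequence

from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore

from langchain.indexes import SQLRecordManager, index


class ToyVectorStore(VectorStore):
    """Hypothetical store that records the batch_size of each write."""

    def __init__(self) -> None:
        self.store: dict = {}
        self.seen_batch_sizes: List[int] = []

    def add_documents(  # type: ignore[override]
        self,
        documents: Sequence[Document],
        *,
        ids: Optional[Sequence[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> List[str]:
        # With this patch, index() forwards its batch_size keyword here.
        self.seen_batch_sizes.append(batch_size)
        self.store.update(dict(zip(ids or [], documents)))
        return list(ids or [])

    def delete(
        self, ids: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        for _id in ids or []:
            self.store.pop(_id, None)
        return True

    # Minimal stubs for the rest of the VectorStore interface.
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        return []

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Any,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "ToyVectorStore":
        return cls()


record_manager = SQLRecordManager(
    "toy/demo", db_url="sqlite:///record_manager_cache.sql"
)
record_manager.create_schema()
vector_store = ToyVectorStore()

docs = [
    Document(page_content="hello", metadata={"source": "a.txt"}),
    Document(page_content="world", metadata={"source": "b.txt"}),
]

# With batch_size=1, index() splits the two documents into two batches and
# calls add_documents() once per batch, forwarding batch_size each time.
result = index(
    docs,
    record_manager,
    vector_store,
    batch_size=1,
    cleanup=None,
    source_id_key="source",
)
print(result)  # first run: {'num_added': 2, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}
print(vector_store.seen_batch_sizes)  # [1, 1]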