diff --git a/libs/langchain/langchain/indexes/_api.py b/libs/langchain/langchain/indexes/_api.py index 01371c4b73f..f41a221795a 100644 --- a/libs/langchain/langchain/indexes/_api.py +++ b/libs/langchain/langchain/indexes/_api.py @@ -330,7 +330,7 @@ def index( # Be pessimistic and assume that all vector store write will fail. # First write to vector store if docs_to_index: - vector_store.add_documents(docs_to_index, ids=uids) + vector_store.add_documents(docs_to_index, ids=uids, batch_size=batch_size) num_added += len(docs_to_index) - len(seen_docs) num_updated += len(seen_docs) @@ -544,7 +544,9 @@ async def aindex( # Be pessimistic and assume that all vector store write will fail. # First write to vector store if docs_to_index: - await vector_store.aadd_documents(docs_to_index, ids=uids) + await vector_store.aadd_documents( + docs_to_index, ids=uids, batch_size=batch_size + ) num_added += len(docs_to_index) - len(seen_docs) num_updated += len(seen_docs) diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index 73c906a1850..10275db9439 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -20,7 +20,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VST, VectorStore from langchain.indexes import aindex, index -from langchain.indexes._api import _abatch +from langchain.indexes._api import _abatch, _HashedDocument from langchain.indexes._sql_record_manager import SQLRecordManager @@ -1304,3 +1304,44 @@ async def test_aindexing_force_update( "num_skipped": 0, "num_updated": 2, } + + +def test_indexing_custom_batch_size( + record_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test indexing with a custom batch size.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + ids = 
[_HashedDocument.from_document(doc).uid for doc in docs] + + batch_size = 1 + with patch.object(vector_store, "add_documents") as mock_add_documents: + index(docs, record_manager, vector_store, batch_size=batch_size) + args, kwargs = mock_add_documents.call_args + assert args == (docs,) + assert kwargs == {"ids": ids, "batch_size": batch_size} + + +@pytest.mark.requires("aiosqlite") +async def test_aindexing_custom_batch_size( + arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore +) -> None: + """Test async indexing with a custom batch size.""" + docs = [ + Document( + page_content="This is a test document.", + metadata={"source": "1"}, + ), + ] + ids = [_HashedDocument.from_document(doc).uid for doc in docs] + + batch_size = 1 + with patch.object(vector_store, "aadd_documents") as mock_add_documents: + await aindex(docs, arecord_manager, vector_store, batch_size=batch_size) + args, kwargs = mock_add_documents.call_args + assert args == (docs,) + assert kwargs == {"ids": ids, "batch_size": batch_size}