From 78546e9242f4eefb72cbc85f023da885c29d1441 Mon Sep 17 00:00:00 2001 From: Sharvil Saxena <97365286+sharziki@users.noreply.github.com> Date: Sun, 26 Apr 2026 15:13:20 -0400 Subject: [PATCH] fix(core): validate batch_size in _batch and _abatch to prevent infinite loop (#36663) --- libs/core/langchain_core/indexing/api.py | 6 ++++++ .../unit_tests/indexing/test_indexing.py | 21 +++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index da3c54fd63f..b4af08b8b54 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -90,6 +90,9 @@ def _hash_nested_dict( def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]: """Utility batching function.""" + if size <= 0: + msg = f"Batch size must be a positive integer, got {size}." + raise ValueError(msg) it = iter(iterable) while True: chunk = list(islice(it, size)) @@ -100,6 +103,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]: async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]: """Utility batching function.""" + if size <= 0: + msg = f"Batch size must be a positive integer, got {size}." + raise ValueError(msg) batch: list[T] = [] async for element in iterable: if len(batch) < size: diff --git a/libs/core/tests/unit_tests/indexing/test_indexing.py b/libs/core/tests/unit_tests/indexing/test_indexing.py index f598048a98f..f89786e17d7 100644 --- a/libs/core/tests/unit_tests/indexing/test_indexing.py +++ b/libs/core/tests/unit_tests/indexing/test_indexing.py @@ -16,6 +16,7 @@ from langchain_core.indexing import InMemoryRecordManager, aindex, index from langchain_core.indexing.api import ( IndexingException, _abatch, + _batch, _get_document_with_hash, ) from langchain_core.indexing.in_memory import InMemoryDocumentIndex @@ -2433,6 +2434,26 @@ async def test_abatch() -> None: assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]] +def test_batch_validation() -> None: + """Test that _batch raises ValueError for non-positive batch sizes.""" + with pytest.raises(ValueError, match="Batch size must be a positive integer"): + list(_batch(0, [1, 2, 3])) + + with pytest.raises(ValueError, match="Batch size must be a positive integer"): + list(_batch(-1, [1, 2, 3])) + + +async def test_abatch_validation() -> None: + """Test that _abatch raises ValueError for non-positive batch sizes.""" + with pytest.raises(ValueError, match="Batch size must be a positive integer"): + async for _ in _abatch(0, _to_async_iter([1, 2, 3])): + pass + + with pytest.raises(ValueError, match="Batch size must be a positive integer"): + async for _ in _abatch(-1, _to_async_iter([1, 2, 3])): + pass + + def test_indexing_force_update( record_manager: InMemoryRecordManager, upserting_vector_store: VectorStore ) -> None: