fix(core): validate batch_size in _batch and _abatch to prevent infinite loop (#36663)

This commit is contained in:
Sharvil Saxena
2026-04-26 15:13:20 -04:00
committed by GitHub
parent 4613a4d951
commit 78546e9242
2 changed files with 27 additions and 0 deletions

View File

@@ -90,6 +90,9 @@ def _hash_nested_dict(
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]: def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
"""Utility batching function.""" """Utility batching function."""
if size <= 0:
msg = f"Batch size must be a positive integer, got {size}."
raise ValueError(msg)
it = iter(iterable) it = iter(iterable)
while True: while True:
chunk = list(islice(it, size)) chunk = list(islice(it, size))
@@ -100,6 +103,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]: async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
"""Utility batching function.""" """Utility batching function."""
if size <= 0:
msg = f"Batch size must be a positive integer, got {size}."
raise ValueError(msg)
batch: list[T] = [] batch: list[T] = []
async for element in iterable: async for element in iterable:
if len(batch) < size: if len(batch) < size:

View File

@@ -16,6 +16,7 @@ from langchain_core.indexing import InMemoryRecordManager, aindex, index
from langchain_core.indexing.api import ( from langchain_core.indexing.api import (
IndexingException, IndexingException,
_abatch, _abatch,
_batch,
_get_document_with_hash, _get_document_with_hash,
) )
from langchain_core.indexing.in_memory import InMemoryDocumentIndex from langchain_core.indexing.in_memory import InMemoryDocumentIndex
@@ -2433,6 +2434,26 @@ async def test_abatch() -> None:
assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]] assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]]
def test_batch_validation() -> None:
"""Test that _batch raises ValueError for non-positive batch sizes."""
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
list(_batch(0, [1, 2, 3]))
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
list(_batch(-1, [1, 2, 3]))
async def test_abatch_validation() -> None:
"""Test that _abatch raises ValueError for non-positive batch sizes."""
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
async for _ in _abatch(0, _to_async_iter([1, 2, 3])):
pass
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
async for _ in _abatch(-1, _to_async_iter([1, 2, 3])):
pass
def test_indexing_force_update( def test_indexing_force_update(
record_manager: InMemoryRecordManager, upserting_vector_store: VectorStore record_manager: InMemoryRecordManager, upserting_vector_store: VectorStore
) -> None: ) -> None: