mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(core): validate batch_size in _batch and _abatch to prevent infinite loop (#36663)
This commit is contained in:
@@ -90,6 +90,9 @@ def _hash_nested_dict(
|
|||||||
|
|
||||||
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
||||||
"""Utility batching function."""
|
"""Utility batching function."""
|
||||||
|
if size <= 0:
|
||||||
|
msg = f"Batch size must be a positive integer, got {size}."
|
||||||
|
raise ValueError(msg)
|
||||||
it = iter(iterable)
|
it = iter(iterable)
|
||||||
while True:
|
while True:
|
||||||
chunk = list(islice(it, size))
|
chunk = list(islice(it, size))
|
||||||
@@ -100,6 +103,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
|||||||
|
|
||||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
|
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
|
||||||
"""Utility batching function."""
|
"""Utility batching function."""
|
||||||
|
if size <= 0:
|
||||||
|
msg = f"Batch size must be a positive integer, got {size}."
|
||||||
|
raise ValueError(msg)
|
||||||
batch: list[T] = []
|
batch: list[T] = []
|
||||||
async for element in iterable:
|
async for element in iterable:
|
||||||
if len(batch) < size:
|
if len(batch) < size:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from langchain_core.indexing import InMemoryRecordManager, aindex, index
|
|||||||
from langchain_core.indexing.api import (
|
from langchain_core.indexing.api import (
|
||||||
IndexingException,
|
IndexingException,
|
||||||
_abatch,
|
_abatch,
|
||||||
|
_batch,
|
||||||
_get_document_with_hash,
|
_get_document_with_hash,
|
||||||
)
|
)
|
||||||
from langchain_core.indexing.in_memory import InMemoryDocumentIndex
|
from langchain_core.indexing.in_memory import InMemoryDocumentIndex
|
||||||
@@ -2433,6 +2434,26 @@ async def test_abatch() -> None:
|
|||||||
assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]]
|
assert [batch async for batch in batches] == [[0, 1], [2, 3], [4]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_validation() -> None:
|
||||||
|
"""Test that _batch raises ValueError for non-positive batch sizes."""
|
||||||
|
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
|
||||||
|
list(_batch(0, [1, 2, 3]))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
|
||||||
|
list(_batch(-1, [1, 2, 3]))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_abatch_validation() -> None:
|
||||||
|
"""Test that _abatch raises ValueError for non-positive batch sizes."""
|
||||||
|
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
|
||||||
|
async for _ in _abatch(0, _to_async_iter([1, 2, 3])):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Batch size must be a positive integer"):
|
||||||
|
async for _ in _abatch(-1, _to_async_iter([1, 2, 3])):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def test_indexing_force_update(
|
def test_indexing_force_update(
|
||||||
record_manager: InMemoryRecordManager, upserting_vector_store: VectorStore
|
record_manager: InMemoryRecordManager, upserting_vector_store: VectorStore
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user