mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(core): validate batch_size in _batch and _abatch to prevent infinite loop (#36663)
This commit is contained in:
@@ -90,6 +90,9 @@ def _hash_nested_dict(
|
||||
|
||||
def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
||||
"""Utility batching function."""
|
||||
if size <= 0:
|
||||
msg = f"Batch size must be a positive integer, got {size}."
|
||||
raise ValueError(msg)
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(islice(it, size))
|
||||
@@ -100,6 +103,9 @@ def _batch(size: int, iterable: Iterable[T]) -> Iterator[list[T]]:
|
||||
|
||||
async def _abatch(size: int, iterable: AsyncIterable[T]) -> AsyncIterator[list[T]]:
|
||||
"""Utility batching function."""
|
||||
if size <= 0:
|
||||
msg = f"Batch size must be a positive integer, got {size}."
|
||||
raise ValueError(msg)
|
||||
batch: list[T] = []
|
||||
async for element in iterable:
|
||||
if len(batch) < size:
|
||||
|
||||
Reference in New Issue
Block a user