Add batch util (#9620)

Add a `batch_iterate` utility to langchain that splits an iterable into lists of at most `size` items.
This commit is contained in:
Eugene Yurtsev 2023-08-22 12:31:18 -04:00 committed by GitHub
parent acb54d8b9d
commit 3408810748
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 0 deletions

View File

@ -1,10 +1,12 @@
from collections import deque
from itertools import islice
from typing import (
Any,
ContextManager,
Deque,
Generator,
Generic,
Iterable,
Iterator,
List,
Optional,
@ -161,3 +163,13 @@ class Tee(Generic[T]):
# Why this is needed: https://stackoverflow.com/a/44638570
# NOTE(review): presumably "this" refers to the `safetee` alias below — confirm.
safetee = Tee
def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
    """Split an iterable into consecutive lists of at most ``size`` items.

    The input is consumed lazily: a batch is only pulled from the underlying
    iterator when the caller requests it, so generators and other one-shot
    iterables are supported.

    Args:
        size: Maximum number of items per batch. Must be at least 1.
        iterable: The items to group into batches.

    Yields:
        Lists of up to ``size`` items, in input order. Only the final batch
        may be shorter than ``size``; nothing is yielded for an empty input.

    Raises:
        ValueError: If ``size`` is less than 1. Because this is a generator,
            the error surfaces when iteration starts, not at call time.
    """
    # A size of 0 would make the first slice empty and end the generator
    # immediately, silently discarding every item, so reject it explicitly
    # (negative sizes are rejected by ``islice`` with the same exception type).
    # ``None`` is still handed to ``islice`` unchanged, which treats it as
    # "no limit".
    if size is not None and size < 1:
        raise ValueError(f"Batch size must be at least 1, got {size}.")
    it = iter(iterable)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            # The iterator is exhausted; stop without yielding an empty batch.
            return
        yield chunk

View File

@ -0,0 +1,21 @@
from typing import List
import pytest
from langchain.utils.iter import batch_iterate
@pytest.mark.parametrize(
    "input_size, input_iterable, expected_output",
    [
        # Input length not a multiple of the batch size: last batch is short.
        (2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
        (3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
        # Batch size of one: every item gets its own batch.
        (1, [100, 200, 300], [[100], [200], [300]]),
        # Batch size larger than the input: a single short batch.
        (10, [1, 2], [[1, 2]]),
        # Empty input: no batches at all.
        (4, [], []),
    ],
)
def test_batch_iterate(
    input_size: int, input_iterable: List[int], expected_output: List[List[int]]
) -> None:
    """Check that batch_iterate groups items into the expected batches."""
    assert list(batch_iterate(input_size, input_iterable)) == expected_output