def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
    """Split ``iterable`` into consecutive lists of at most ``size`` items.

    Items are pulled lazily from a single iterator over ``iterable``, so each
    batch is only built when the caller asks for it and the input order is
    preserved. Every batch has exactly ``size`` items except possibly the
    last one, which holds whatever remains. An empty input yields nothing.

    Args:
        size: Maximum number of items per batch. Must be at least 1.
        iterable: The items to group into batches.

    Yields:
        Lists of up to ``size`` consecutive items from ``iterable``.

    Raises:
        ValueError: If ``size`` is smaller than 1. The check runs when the
            generator is first advanced, not when it is created.
    """
    # A size of 0 would make the very first chunk empty, ending the generator
    # immediately and silently discarding every input item; fail loudly
    # instead. ``None`` is left alone because ``islice`` accepts it as "no
    # limit", which the unvalidated implementation also allowed.
    if size is not None and size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")

    it = iter(iterable)
    while True:
        # islice consumes at most ``size`` items from the shared iterator, so
        # consecutive chunks never overlap or skip elements.
        chunk = list(islice(it, size))
        if not chunk:
            # The iterator is exhausted: nothing left to batch.
            return
        yield chunk