mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
parent
acb54d8b9d
commit
3408810748
@ -1,10 +1,12 @@
|
||||
from collections import deque
|
||||
from itertools import islice
|
||||
from typing import (
|
||||
Any,
|
||||
ContextManager,
|
||||
Deque,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
@ -161,3 +163,13 @@ class Tee(Generic[T]):
|
||||
|
||||
# Why this is needed https://stackoverflow.com/a/44638570
|
||||
safetee = Tee
|
||||
|
||||
|
||||
def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||
"""Utility batching function."""
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(islice(it, size))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
0
libs/langchain/tests/unit_tests/utils/__init__.py
Normal file
0
libs/langchain/tests/unit_tests/utils/__init__.py
Normal file
21
libs/langchain/tests/unit_tests/utils/test_iter.py
Normal file
21
libs/langchain/tests/unit_tests/utils/test_iter.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.utils.iter import batch_iterate
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_size, input_iterable, expected_output",
|
||||
[
|
||||
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||
(1, [100, 200, 300], [[100], [200], [300]]),
|
||||
(4, [], []),
|
||||
],
|
||||
)
|
||||
def test_batch_iterate(
|
||||
input_size: int, input_iterable: List[str], expected_output: List[str]
|
||||
) -> None:
|
||||
"""Test batching function."""
|
||||
assert list(batch_iterate(input_size, input_iterable)) == expected_output
|
Loading…
Reference in New Issue
Block a user