mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
parent
acb54d8b9d
commit
3408810748
@ -1,10 +1,12 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
|
from itertools import islice
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
ContextManager,
|
ContextManager,
|
||||||
Deque,
|
Deque,
|
||||||
Generator,
|
Generator,
|
||||||
Generic,
|
Generic,
|
||||||
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -161,3 +163,13 @@ class Tee(Generic[T]):
|
|||||||
|
|
||||||
# Why this is needed https://stackoverflow.com/a/44638570
|
# Why this is needed https://stackoverflow.com/a/44638570
|
||||||
safetee = Tee
|
safetee = Tee
|
||||||
|
|
||||||
|
|
||||||
|
def batch_iterate(size: int, iterable: Iterable[T]) -> Iterator[List[T]]:
|
||||||
|
"""Utility batching function."""
|
||||||
|
it = iter(iterable)
|
||||||
|
while True:
|
||||||
|
chunk = list(islice(it, size))
|
||||||
|
if not chunk:
|
||||||
|
return
|
||||||
|
yield chunk
|
||||||
|
0
libs/langchain/tests/unit_tests/utils/__init__.py
Normal file
0
libs/langchain/tests/unit_tests/utils/__init__.py
Normal file
21
libs/langchain/tests/unit_tests/utils/test_iter.py
Normal file
21
libs/langchain/tests/unit_tests/utils/test_iter.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.utils.iter import batch_iterate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_size, input_iterable, expected_output",
|
||||||
|
[
|
||||||
|
(2, [1, 2, 3, 4, 5], [[1, 2], [3, 4], [5]]),
|
||||||
|
(3, [10, 20, 30, 40, 50], [[10, 20, 30], [40, 50]]),
|
||||||
|
(1, [100, 200, 300], [[100], [200], [300]]),
|
||||||
|
(4, [], []),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_batch_iterate(
|
||||||
|
input_size: int, input_iterable: List[str], expected_output: List[str]
|
||||||
|
) -> None:
|
||||||
|
"""Test batching function."""
|
||||||
|
assert list(batch_iterate(input_size, input_iterable)) == expected_output
|
Loading…
Reference in New Issue
Block a user