diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index 0430c03c824..ee07c0cfc6e 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -75,5 +75,19 @@ def test_put_in_map() -> None: ), ], ) -def test_put_get_sequence(runnable: RunnableSequence) -> None: - assert runnable.invoke("hello") == {"input": "hello", "output": "hello"} +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}), + (lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]), + ( + lambda r, x: list(r.stream(x))[0], + "hello", + {"input": "hello", "output": "hello"}, + ), + ], +) +def test_put_get_sequence( + runnable: RunnableSequence, method: Callable, input: Any, output: Any +) -> None: + assert method(runnable, input) == output