diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index 8f8755a9644..0430c03c824 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -1,3 +1,5 @@ +from typing import Any, Callable, Type + import pytest from langchain import PromptTemplate @@ -10,30 +12,42 @@ from langchain.schema.runnable import ( ) -@pytest.mark.asyncio -async def test_put_get() -> None: +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.invoke(x), "foo", "foo"), + (lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]), + (lambda r, x: list(r.stream(x))[0], "foo", "foo"), + ], +) +def test_put_get(method: Callable, input: Any, output: Any) -> None: runnable = PutLocalVar("input") | GetLocalVar("input") - assert runnable.invoke("foo") == "foo" - assert runnable.batch(["foo", "bar"]) == ["foo", "bar"] - assert list(runnable.stream("foo"))[0] == "foo" - - assert await runnable.ainvoke("foo") == "foo" - assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"] - async for x in runnable.astream("foo"): - assert x == "foo" + assert method(runnable, input) == output -def test_missing_config() -> None: - with pytest.raises(ValueError): - PutLocalVar("input").invoke("foo") - - with pytest.raises(ValueError): - GetLocalVar[str, str]("input").invoke("foo") +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.ainvoke(x), "foo", "foo"), + (lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]), + ], +) +async def test_put_get_async(method: Callable, input: Any, output: Any) -> None: + runnable = PutLocalVar("input") | GetLocalVar("input") + assert await method(runnable, input) == output -def test_get_missing_var_invoke() -> None: - runnable = PutLocalVar("input") | GetLocalVar("missing") - with pytest.raises(KeyError): +@pytest.mark.parametrize( + ("runnable", "error"), + [ + (PutLocalVar("input"), ValueError), + (GetLocalVar("input"), ValueError), + (PutLocalVar("input") | GetLocalVar("missing"), KeyError), + ], +) +def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None: + with pytest.raises(error): runnable.invoke("foo") @@ -42,24 +56,24 @@ def test_get_in_map() -> None: assert runnable.invoke("foo") == {"bar": "foo"} -def test_cant_put_in_map() -> None: +def test_put_in_map() -> None: runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") with pytest.raises(KeyError): runnable.invoke("foo") -def test_get_passthrough_key() -> None: - runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output") - assert runnable.invoke("foo") == {"input": "foo", "output": "foo"} - - -def test_multi_step_sequence() -> None: - prompt = PromptTemplate.from_template("say {foo}") - runnable = ( - PutLocalVar("foo") - | {"foo": RunnablePassthrough()} - | prompt - | FakeListLLM(responses=["bar"]) - | GetLocalVar("foo", passthrough_key="output") - ) - assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"} +@pytest.mark.parametrize( + "runnable", + [ + PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"), + ( + PutLocalVar("input") + | {"input": RunnablePassthrough()} + | PromptTemplate.from_template("say {input}") + | FakeListLLM(responses=["hello"]) + | GetLocalVar("input", passthrough_key="output") + ), + ], +) +def test_put_get_sequence(runnable: RunnableSequence) -> None: + assert runnable.invoke("hello") == {"input": "hello", "output": "hello"}