This commit is contained in:
Bagatur 2023-08-21 17:31:38 -07:00
parent 354c42afd2
commit 182b059bf4

View File

@ -1,3 +1,5 @@
from typing import Any, Callable, Type
import pytest
from langchain import PromptTemplate
@ -10,30 +12,42 @@ from langchain.schema.runnable import (
)
@pytest.mark.asyncio
async def test_put_get() -> None:
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.invoke(x), "foo", "foo"),
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
],
)
def test_put_get(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert runnable.invoke("foo") == "foo"
assert runnable.batch(["foo", "bar"]) == ["foo", "bar"]
assert list(runnable.stream("foo"))[0] == "foo"
assert await runnable.ainvoke("foo") == "foo"
assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"]
async for x in runnable.astream("foo"):
assert x == "foo"
assert method(runnable, input) == output
def test_missing_config() -> None:
with pytest.raises(ValueError):
PutLocalVar("input").invoke("foo")
with pytest.raises(ValueError):
GetLocalVar[str, str]("input").invoke("foo")
@pytest.mark.asyncio
@pytest.mark.parametrize(
("method", "input", "output"),
[
(lambda r, x: r.ainvoke(x), "foo", "foo"),
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
],
)
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert await method(runnable, input) == output
def test_get_missing_var_invoke() -> None:
runnable = PutLocalVar("input") | GetLocalVar("missing")
with pytest.raises(KeyError):
@pytest.mark.parametrize(
("runnable", "error"),
[
(PutLocalVar("input"), ValueError),
(GetLocalVar("input"), ValueError),
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
],
)
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
with pytest.raises(error):
runnable.invoke("foo")
@ -42,24 +56,24 @@ def test_get_in_map() -> None:
assert runnable.invoke("foo") == {"bar": "foo"}
def test_cant_put_in_map() -> None:
def test_put_in_map() -> None:
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError):
runnable.invoke("foo")
def test_get_passthrough_key() -> None:
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"}
def test_multi_step_sequence() -> None:
prompt = PromptTemplate.from_template("say {foo}")
runnable = (
PutLocalVar("foo")
| {"foo": RunnablePassthrough()}
| prompt
| FakeListLLM(responses=["bar"])
| GetLocalVar("foo", passthrough_key="output")
)
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}
@pytest.mark.parametrize(
"runnable",
[
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
(
PutLocalVar("input")
| {"input": RunnablePassthrough()}
| PromptTemplate.from_template("say {input}")
| FakeListLLM(responses=["hello"])
| GetLocalVar("input", passthrough_key="output")
),
],
)
def test_put_get_sequence(runnable: RunnableSequence) -> None:
assert runnable.invoke("hello") == {"input": "hello", "output": "hello"}