This commit is contained in:
Bagatur 2023-08-21 17:31:38 -07:00
parent 354c42afd2
commit 182b059bf4

View File

@ -1,3 +1,5 @@
from typing import Any, Callable, Type
import pytest import pytest
from langchain import PromptTemplate from langchain import PromptTemplate
@ -10,30 +12,42 @@ from langchain.schema.runnable import (
) )
@pytest.mark.asyncio @pytest.mark.parametrize(
async def test_put_get() -> None: ("method", "input", "output"),
[
(lambda r, x: r.invoke(x), "foo", "foo"),
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
],
)
def test_put_get(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input") runnable = PutLocalVar("input") | GetLocalVar("input")
assert runnable.invoke("foo") == "foo" assert method(runnable, input) == output
assert runnable.batch(["foo", "bar"]) == ["foo", "bar"]
assert list(runnable.stream("foo"))[0] == "foo"
assert await runnable.ainvoke("foo") == "foo"
assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"]
async for x in runnable.astream("foo"):
assert x == "foo"
def test_missing_config() -> None: @pytest.mark.asyncio
with pytest.raises(ValueError): @pytest.mark.parametrize(
PutLocalVar("input").invoke("foo") ("method", "input", "output"),
[
with pytest.raises(ValueError): (lambda r, x: r.ainvoke(x), "foo", "foo"),
GetLocalVar[str, str]("input").invoke("foo") (lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
],
)
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
runnable = PutLocalVar("input") | GetLocalVar("input")
assert await method(runnable, input) == output
def test_get_missing_var_invoke() -> None: @pytest.mark.parametrize(
runnable = PutLocalVar("input") | GetLocalVar("missing") ("runnable", "error"),
with pytest.raises(KeyError): [
(PutLocalVar("input"), ValueError),
(GetLocalVar("input"), ValueError),
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
],
)
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
with pytest.raises(error):
runnable.invoke("foo") runnable.invoke("foo")
@ -42,24 +56,24 @@ def test_get_in_map() -> None:
assert runnable.invoke("foo") == {"bar": "foo"} assert runnable.invoke("foo") == {"bar": "foo"}
def test_cant_put_in_map() -> None: def test_put_in_map() -> None:
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError): with pytest.raises(KeyError):
runnable.invoke("foo") runnable.invoke("foo")
def test_get_passthrough_key() -> None: @pytest.mark.parametrize(
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output") "runnable",
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"} [
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
(
def test_multi_step_sequence() -> None: PutLocalVar("input")
prompt = PromptTemplate.from_template("say {foo}") | {"input": RunnablePassthrough()}
runnable = ( | PromptTemplate.from_template("say {input}")
PutLocalVar("foo") | FakeListLLM(responses=["hello"])
| {"foo": RunnablePassthrough()} | GetLocalVar("input", passthrough_key="output")
| prompt ),
| FakeListLLM(responses=["bar"]) ],
| GetLocalVar("foo", passthrough_key="output") )
) def test_put_get_sequence(runnable: RunnableSequence) -> None:
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"} assert runnable.invoke("hello") == {"input": "hello", "output": "hello"}