mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
param
This commit is contained in:
parent
354c42afd2
commit
182b059bf4
@ -1,3 +1,5 @@
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain import PromptTemplate
|
||||
@ -10,30 +12,42 @@ from langchain.schema.runnable import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_get() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
("method", "input", "output"),
|
||||
[
|
||||
(lambda r, x: r.invoke(x), "foo", "foo"),
|
||||
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
|
||||
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
|
||||
],
|
||||
)
|
||||
def test_put_get(method: Callable, input: Any, output: Any) -> None:
|
||||
runnable = PutLocalVar("input") | GetLocalVar("input")
|
||||
assert runnable.invoke("foo") == "foo"
|
||||
assert runnable.batch(["foo", "bar"]) == ["foo", "bar"]
|
||||
assert list(runnable.stream("foo"))[0] == "foo"
|
||||
|
||||
assert await runnable.ainvoke("foo") == "foo"
|
||||
assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"]
|
||||
async for x in runnable.astream("foo"):
|
||||
assert x == "foo"
|
||||
assert method(runnable, input) == output
|
||||
|
||||
|
||||
def test_missing_config() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
PutLocalVar("input").invoke("foo")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
GetLocalVar[str, str]("input").invoke("foo")
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("method", "input", "output"),
|
||||
[
|
||||
(lambda r, x: r.ainvoke(x), "foo", "foo"),
|
||||
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
|
||||
],
|
||||
)
|
||||
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
|
||||
runnable = PutLocalVar("input") | GetLocalVar("input")
|
||||
assert await method(runnable, input) == output
|
||||
|
||||
|
||||
def test_get_missing_var_invoke() -> None:
|
||||
runnable = PutLocalVar("input") | GetLocalVar("missing")
|
||||
with pytest.raises(KeyError):
|
||||
@pytest.mark.parametrize(
|
||||
("runnable", "error"),
|
||||
[
|
||||
(PutLocalVar("input"), ValueError),
|
||||
(GetLocalVar("input"), ValueError),
|
||||
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
|
||||
],
|
||||
)
|
||||
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
|
||||
with pytest.raises(error):
|
||||
runnable.invoke("foo")
|
||||
|
||||
|
||||
@ -42,24 +56,24 @@ def test_get_in_map() -> None:
|
||||
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||
|
||||
|
||||
def test_cant_put_in_map() -> None:
|
||||
def test_put_in_map() -> None:
|
||||
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||
with pytest.raises(KeyError):
|
||||
runnable.invoke("foo")
|
||||
|
||||
|
||||
def test_get_passthrough_key() -> None:
|
||||
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
|
||||
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"}
|
||||
|
||||
|
||||
def test_multi_step_sequence() -> None:
|
||||
prompt = PromptTemplate.from_template("say {foo}")
|
||||
runnable = (
|
||||
PutLocalVar("foo")
|
||||
| {"foo": RunnablePassthrough()}
|
||||
| prompt
|
||||
| FakeListLLM(responses=["bar"])
|
||||
| GetLocalVar("foo", passthrough_key="output")
|
||||
)
|
||||
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}
|
||||
@pytest.mark.parametrize(
|
||||
"runnable",
|
||||
[
|
||||
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
|
||||
(
|
||||
PutLocalVar("input")
|
||||
| {"input": RunnablePassthrough()}
|
||||
| PromptTemplate.from_template("say {input}")
|
||||
| FakeListLLM(responses=["hello"])
|
||||
| GetLocalVar("input", passthrough_key="output")
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_put_get_sequence(runnable: RunnableSequence) -> None:
|
||||
assert runnable.invoke("hello") == {"input": "hello", "output": "hello"}
|
||||
|
Loading…
Reference in New Issue
Block a user