mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
param
This commit is contained in:
parent
354c42afd2
commit
182b059bf4
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any, Callable, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain import PromptTemplate
|
from langchain import PromptTemplate
|
||||||
@ -10,30 +12,42 @@ from langchain.schema.runnable import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.parametrize(
|
||||||
async def test_put_get() -> None:
|
("method", "input", "output"),
|
||||||
|
[
|
||||||
|
(lambda r, x: r.invoke(x), "foo", "foo"),
|
||||||
|
(lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]),
|
||||||
|
(lambda r, x: list(r.stream(x))[0], "foo", "foo"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_put_get(method: Callable, input: Any, output: Any) -> None:
|
||||||
runnable = PutLocalVar("input") | GetLocalVar("input")
|
runnable = PutLocalVar("input") | GetLocalVar("input")
|
||||||
assert runnable.invoke("foo") == "foo"
|
assert method(runnable, input) == output
|
||||||
assert runnable.batch(["foo", "bar"]) == ["foo", "bar"]
|
|
||||||
assert list(runnable.stream("foo"))[0] == "foo"
|
|
||||||
|
|
||||||
assert await runnable.ainvoke("foo") == "foo"
|
|
||||||
assert await runnable.abatch(["foo", "bar"]) == ["foo", "bar"]
|
|
||||||
async for x in runnable.astream("foo"):
|
|
||||||
assert x == "foo"
|
|
||||||
|
|
||||||
|
|
||||||
def test_missing_config() -> None:
|
@pytest.mark.asyncio
|
||||||
with pytest.raises(ValueError):
|
@pytest.mark.parametrize(
|
||||||
PutLocalVar("input").invoke("foo")
|
("method", "input", "output"),
|
||||||
|
[
|
||||||
with pytest.raises(ValueError):
|
(lambda r, x: r.ainvoke(x), "foo", "foo"),
|
||||||
GetLocalVar[str, str]("input").invoke("foo")
|
(lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_put_get_async(method: Callable, input: Any, output: Any) -> None:
|
||||||
|
runnable = PutLocalVar("input") | GetLocalVar("input")
|
||||||
|
assert await method(runnable, input) == output
|
||||||
|
|
||||||
|
|
||||||
def test_get_missing_var_invoke() -> None:
|
@pytest.mark.parametrize(
|
||||||
runnable = PutLocalVar("input") | GetLocalVar("missing")
|
("runnable", "error"),
|
||||||
with pytest.raises(KeyError):
|
[
|
||||||
|
(PutLocalVar("input"), ValueError),
|
||||||
|
(GetLocalVar("input"), ValueError),
|
||||||
|
(PutLocalVar("input") | GetLocalVar("missing"), KeyError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None:
|
||||||
|
with pytest.raises(error):
|
||||||
runnable.invoke("foo")
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
@ -42,24 +56,24 @@ def test_get_in_map() -> None:
|
|||||||
assert runnable.invoke("foo") == {"bar": "foo"}
|
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||||
|
|
||||||
|
|
||||||
def test_cant_put_in_map() -> None:
|
def test_put_in_map() -> None:
|
||||||
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
runnable.invoke("foo")
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
|
||||||
def test_get_passthrough_key() -> None:
|
@pytest.mark.parametrize(
|
||||||
runnable = PutLocalVar("input") | GetLocalVar("input", passthrough_key="output")
|
"runnable",
|
||||||
assert runnable.invoke("foo") == {"input": "foo", "output": "foo"}
|
[
|
||||||
|
PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"),
|
||||||
|
(
|
||||||
def test_multi_step_sequence() -> None:
|
PutLocalVar("input")
|
||||||
prompt = PromptTemplate.from_template("say {foo}")
|
| {"input": RunnablePassthrough()}
|
||||||
runnable = (
|
| PromptTemplate.from_template("say {input}")
|
||||||
PutLocalVar("foo")
|
| FakeListLLM(responses=["hello"])
|
||||||
| {"foo": RunnablePassthrough()}
|
| GetLocalVar("input", passthrough_key="output")
|
||||||
| prompt
|
),
|
||||||
| FakeListLLM(responses=["bar"])
|
],
|
||||||
| GetLocalVar("foo", passthrough_key="output")
|
)
|
||||||
)
|
def test_put_get_sequence(runnable: RunnableSequence) -> None:
|
||||||
assert runnable.invoke("hello") == {"foo": "hello", "output": "bar"}
|
assert runnable.invoke("hello") == {"input": "hello", "output": "hello"}
|
||||||
|
Loading…
Reference in New Issue
Block a user