mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
Add unit test for fix
This commit is contained in:
parent
acffca8bb7
commit
c9682d9532
@ -5452,3 +5452,37 @@ def test_runnable_assign() -> None:
|
|||||||
|
|
||||||
result = runnable_assign.invoke({"input": 5})
|
result = runnable_assign.invoke({"input": 5})
|
||||||
assert result == {"input": 5, "add_step": {"added": 15}}
|
assert result == {"input": 5, "add_step": {"added": 15}}
|
||||||
|
|
||||||
|
def test_runnable_typed_dict_schema() -> None:
|
||||||
|
"""Testing that the schema is generated properly(not empty) when using TypedDict
|
||||||
|
|
||||||
|
subclasses to annotate the arguments of a RunnableParallel children.
|
||||||
|
"""
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
from langchain_core.runnables import RunnableParallel, RunnableLambda
|
||||||
|
|
||||||
|
class Foo(TypedDict):
|
||||||
|
foo: str
|
||||||
|
|
||||||
|
class InputData(Foo):
|
||||||
|
bar: str
|
||||||
|
|
||||||
|
def forward_foo(input_data: InputData):
|
||||||
|
return input_data["foo"]
|
||||||
|
|
||||||
|
def transform_input(input_data: InputData):
|
||||||
|
foo = input_data["foo"]
|
||||||
|
bar = input_data["bar"]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"transformed": foo + bar
|
||||||
|
}
|
||||||
|
|
||||||
|
foo_runnable = RunnableLambda(forward_foo)
|
||||||
|
other_runnable = RunnableLambda(transform_input)
|
||||||
|
|
||||||
|
parallel = RunnableParallel(
|
||||||
|
foo=foo_runnable,
|
||||||
|
other=other_runnable,
|
||||||
|
)
|
||||||
|
assert(repr(parallel.input_schema.validate({ "foo": "Y", "bar": "Z" })) == "RunnableParallel<foo,other>Input(root={'foo': 'Y', 'bar': 'Z'})")
|
Loading…
Reference in New Issue
Block a user