Add unit test for fix

This commit is contained in:
William Zhu 2024-11-18 16:45:24 -05:00
parent acffca8bb7
commit c9682d9532

View File

@ -5452,3 +5452,37 @@ def test_runnable_assign() -> None:
result = runnable_assign.invoke({"input": 5})
assert result == {"input": 5, "add_step": {"added": 15}}
def test_runnable_typed_dict_schema() -> None:
"""Testing that the schema is generated properly(not empty) when using TypedDict
subclasses to annotate the arguments of a RunnableParallel children.
"""
from typing_extensions import TypedDict
from langchain_core.runnables import RunnableParallel, RunnableLambda
class Foo(TypedDict):
foo: str
class InputData(Foo):
bar: str
def forward_foo(input_data: InputData):
return input_data["foo"]
def transform_input(input_data: InputData):
foo = input_data["foo"]
bar = input_data["bar"]
return {
"transformed": foo + bar
}
foo_runnable = RunnableLambda(forward_foo)
other_runnable = RunnableLambda(transform_input)
parallel = RunnableParallel(
foo=foo_runnable,
other=other_runnable,
)
assert(repr(parallel.input_schema.validate({ "foo": "Y", "bar": "Z" })) == "RunnableParallel<foo,other>Input(root={'foo': 'Y', 'bar': 'Z'})")