mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 08:06:14 +00:00
Add unit test for fix
This commit is contained in:
parent
acffca8bb7
commit
c9682d9532
@ -5452,3 +5452,37 @@ def test_runnable_assign() -> None:
|
||||
|
||||
result = runnable_assign.invoke({"input": 5})
|
||||
assert result == {"input": 5, "add_step": {"added": 15}}
|
||||
|
||||
def test_runnable_typed_dict_schema() -> None:
|
||||
"""Testing that the schema is generated properly(not empty) when using TypedDict
|
||||
|
||||
subclasses to annotate the arguments of a RunnableParallel children.
|
||||
"""
|
||||
from typing_extensions import TypedDict
|
||||
from langchain_core.runnables import RunnableParallel, RunnableLambda
|
||||
|
||||
class Foo(TypedDict):
|
||||
foo: str
|
||||
|
||||
class InputData(Foo):
|
||||
bar: str
|
||||
|
||||
def forward_foo(input_data: InputData):
|
||||
return input_data["foo"]
|
||||
|
||||
def transform_input(input_data: InputData):
|
||||
foo = input_data["foo"]
|
||||
bar = input_data["bar"]
|
||||
|
||||
return {
|
||||
"transformed": foo + bar
|
||||
}
|
||||
|
||||
foo_runnable = RunnableLambda(forward_foo)
|
||||
other_runnable = RunnableLambda(transform_input)
|
||||
|
||||
parallel = RunnableParallel(
|
||||
foo=foo_runnable,
|
||||
other=other_runnable,
|
||||
)
|
||||
assert(repr(parallel.input_schema.validate({ "foo": "Y", "bar": "Z" })) == "RunnableParallel<foo,other>Input(root={'foo': 'Y', 'bar': 'Z'})")
|
Loading…
Reference in New Issue
Block a user