diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 63933394d2b..2dceff9ca30 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -434,14 +434,17 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): if not issubclass(map_input_schema, RootModel) and not issubclass( map_output_schema, RootModel ): - # ie. both are dicts + fields = {} + + for name, field_info in map_input_schema.model_fields.items(): + fields[name] = (field_info.annotation, field_info.default) + + for name, field_info in map_output_schema.model_fields.items(): + fields[name] = (field_info.annotation, field_info.default) + return create_model( # type: ignore[call-overload] "RunnableAssignOutput", - **{ - k: (v.type_, v.default) - for s in (map_input_schema, map_output_schema) - for k, v in s.model_fields.items() - }, + **fields, ) elif not issubclass(map_output_schema, RootModel): # ie. only map output is a dict