diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index f04d34c1648..3ea4325f3c7 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -348,6 +348,14 @@ class Runnable(Generic[Input, Output], ABC): return create_model( self.get_name("Input"), __root__=root_type, + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) def get_input_jsonschema( @@ -408,6 +416,14 @@ class Runnable(Generic[Input, Output], ABC): return create_model( self.get_name("Output"), __root__=root_type, + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) def get_output_jsonschema( @@ -4046,6 +4062,29 @@ class RunnableGenerator(Runnable[Input, Output]): except ValueError: return Any + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # Override the default implementation. + # For a runnable generator, we need to bring to provide the + # module of the underlying function when creating the model. + root_type = self.InputType + + func = getattr(self, "_transform", None) or self._atransform + module = getattr(func, "__module__", None) + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.get_name("Input"), + __root__=root_type, + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, + ) + @property def OutputType(self) -> Any: func = getattr(self, "_transform", None) or self._atransform @@ -4059,6 +4098,28 @@ class RunnableGenerator(Runnable[Input, Output]): except ValueError: return Any + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # Override the default implementation. + # For a runnable generator, we need to bring to provide the + # module of the underlying function when creating the model. + root_type = self.OutputType + func = getattr(self, "_transform", None) or self._atransform + module = getattr(func, "__module__", None) + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.get_name("Output"), + __root__=root_type, + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, RunnableGenerator): if hasattr(self, "_transform") and hasattr(other, "_transform"): @@ -4307,9 +4368,14 @@ class RunnableLambda(Runnable[Input, Output]): # It's a dict, lol return create_model(self.get_name("Input"), **fields) else: + module = getattr(func, "__module__", None) return create_model( self.get_name("Input"), __root__=List[Any], + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, ) if self.InputType != Any: @@ -4346,6 +4412,28 @@ class RunnableLambda(Runnable[Input, Output]): except ValueError: return Any + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # Override the default implementation. + # For a runnable lambda, we need to bring to provide the + # module of the underlying function when creating the model. + root_type = self.OutputType + func = getattr(self, "func", None) or self.afunc + module = getattr(func, "__module__", None) + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.get_name("Output"), + __root__=root_type, + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, + ) + @property def deps(self) -> List[Runnable]: """The dependencies of this Runnable. @@ -4863,6 +4951,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): List[self.bound.get_input_schema(config)], # type: ignore None, ), + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) @property @@ -4876,6 +4972,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): return create_model( self.get_name("Output"), __root__=List[schema], # type: ignore[valid-type] + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) @property diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 8b9ad729206..e6db922730f 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -21,6 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableBindingBase, Runnabl from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( ConfigurableFieldSpec, + Output, create_model, get_unique_config_specs, ) @@ -362,6 +363,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): history_factory_config=_config_specs, **kwargs, ) + self._history_chain = history_chain @property def config_specs(self) -> List[ConfigurableFieldSpec]: @@ -393,6 +395,39 @@ class RunnableWithMessageHistory(RunnableBindingBase): **fields, ) + @property + def OutputType(self) -> Type[Output]: + output_type = self._history_chain.OutputType + return output_type + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + """Get a pydantic model that can be used to validate output to the Runnable. + + Runnables that leverage the configurable_fields and configurable_alternatives + methods will have a dynamic output schema that depends on which + configuration the Runnable is invoked with. + + This method allows to get an output schema for a specific configuration. + + Args: + config: A config to use when generating the schema. + + Returns: + A pydantic model that can be used to validate output. + """ + root_type = self.OutputType + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + "RunnableWithChatHistoryOutput", + __root__=root_type, + __module_name=self.__class__.__module__, + ) + def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: return False diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index e1316f0dd08..eb03d11346f 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -713,7 +713,10 @@ NO_DEFAULT = object() def _create_root_model( - name: str, type_: Any, default_: object = NO_DEFAULT + name: str, + type_: Any, + module_name: Optional[str] = None, + default_: object = NO_DEFAULT, ) -> Type[BaseModel]: """Create a base class.""" @@ -751,7 +754,7 @@ def _create_root_model( "model_config": ConfigDict(arbitrary_types_allowed=True), "schema": classmethod(schema), "model_json_schema": classmethod(model_json_schema), - "__module__": "langchain_core.runnables.utils", + "__module__": module_name or "langchain_core.runnables.utils", } if default_ is not NO_DEFAULT: @@ -770,18 +773,24 @@ def _create_root_model_cached( __model_name: str, type_: Any, default_: object = NO_DEFAULT, + module_name: Optional[str] = None, ) -> Type[BaseModel]: - return _create_root_model(__model_name, type_, default_) + return _create_root_model( + __model_name, type_, default_=default_, module_name=module_name + ) def create_model( __model_name: str, + __module_name: Optional[str] = None, **field_definitions: Any, ) -> Type[BaseModel]: """Create a pydantic model with the given field definitions. Args: __model_name: The name of the model. + __module_name: The name of the module where the model is defined. + This is used by Pydantic to resolve any forward references. **field_definitions: The field definitions for the model. Returns: @@ -803,10 +812,16 @@ def create_model( kwargs = {"type_": arg} try: - named_root_model = _create_root_model_cached(__model_name, **kwargs) + named_root_model = _create_root_model_cached( + __model_name, module_name=__module_name, **kwargs + ) except TypeError: # something in the arguments into _create_root_model_cached is not hashable - named_root_model = _create_root_model(__model_name, **kwargs) + named_root_model = _create_root_model( + __model_name, + module_name=__module_name, + **kwargs, + ) return named_root_model try: return _create_model_cached(__model_name, **field_definitions) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 22ee2aa1795..377e7a72457 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -454,6 +454,41 @@ def test_get_input_schema_input_dict() -> None: ) +def test_get_output_schema() -> None: + """Test get output schema.""" + runnable = RunnableLambda( + lambda input: { + "output": [ + AIMessage( + content="you said: " + + "\n".join( + [ + str(m.content) + for m in input["history"] + if isinstance(m, HumanMessage) + ] + + [input["input"]] + ) + ) + ] + } + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, + get_session_history, + input_messages_key="input", + history_messages_key="history", + output_messages_key="output", + ) + output_type = with_history.get_output_schema() + + assert _schema(output_type) == { + "title": "RunnableWithChatHistoryOutput", + "type": "object", + } + + def test_get_input_schema_input_messages() -> None: from pydantic import RootModel