From 7975c1f0ca4670998d93b773b66540705b4c1e7f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Tue, 10 Sep 2024 15:22:56 -0400 Subject: [PATCH] core[patch]: Propagate module name to create model (#26267) * This allows pydantic to correctly resolve annotations necessary for building pydantic models dynamically. * Makes a small fix for RunnableWithMessageHistory which was fetching the OutputType from the RunnableLambda that was yielding another RunnableLambda. This doesn't propagate the output of the RunnableAssign fully (i.e., with concrete type information etc.) Resolves issue: https://github.com/langchain-ai/langchain/issues/26250 --- libs/core/langchain_core/runnables/base.py | 104 ++++++++++++++++++ libs/core/langchain_core/runnables/history.py | 35 ++++++ libs/core/langchain_core/runnables/utils.py | 25 ++++- .../unit_tests/runnables/test_history.py | 35 ++++++ 4 files changed, 194 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index f04d34c1648..3ea4325f3c7 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -348,6 +348,14 @@ class Runnable(Generic[Input, Output], ABC): return create_model( self.get_name("Input"), __root__=root_type, + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) def get_input_jsonschema( @@ -408,6 +416,14 @@ class Runnable(Generic[Input, Output], ABC): return create_model( self.get_name("Output"), __root__=root_type, + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) def get_output_jsonschema( @@ -4046,6 +4062,29 @@ class RunnableGenerator(Runnable[Input, Output]): except ValueError: return Any + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # Override the default implementation. + # For a runnable generator, we need to bring to provide the + # module of the underlying function when creating the model. + root_type = self.InputType + + func = getattr(self, "_transform", None) or self._atransform + module = getattr(func, "__module__", None) + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.get_name("Input"), + __root__=root_type, + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, + ) + @property def OutputType(self) -> Any: func = getattr(self, "_transform", None) or self._atransform @@ -4059,6 +4098,28 @@ class RunnableGenerator(Runnable[Input, Output]): except ValueError: return Any + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # Override the default implementation. + # For a runnable generator, we need to bring to provide the + # module of the underlying function when creating the model. + root_type = self.OutputType + func = getattr(self, "_transform", None) or self._atransform + module = getattr(func, "__module__", None) + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.get_name("Output"), + __root__=root_type, + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, + ) + def __eq__(self, other: Any) -> bool: if isinstance(other, RunnableGenerator): if hasattr(self, "_transform") and hasattr(other, "_transform"): @@ -4307,9 +4368,14 @@ class RunnableLambda(Runnable[Input, Output]): # It's a dict, lol return create_model(self.get_name("Input"), **fields) else: + module = getattr(func, "__module__", None) return create_model( self.get_name("Input"), __root__=List[Any], + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, ) if self.InputType != Any: @@ -4346,6 +4412,28 @@ class RunnableLambda(Runnable[Input, Output]): except ValueError: return Any + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + # Override the default implementation. + # For a runnable lambda, we need to bring to provide the + # module of the underlying function when creating the model. + root_type = self.OutputType + func = getattr(self, "func", None) or self.afunc + module = getattr(func, "__module__", None) + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + self.get_name("Output"), + __root__=root_type, + # To create the schema, we need to provide the module + # where the underlying function is defined. + # This allows pydantic to resolve type annotations appropriately. + __module_name=module, + ) + @property def deps(self) -> List[Runnable]: """The dependencies of this Runnable. @@ -4863,6 +4951,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): List[self.bound.get_input_schema(config)], # type: ignore None, ), + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) @property @@ -4876,6 +4972,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): return create_model( self.get_name("Output"), __root__=List[schema], # type: ignore[valid-type] + # create model needs access to appropriate type annotations to be + # able to construct the pydantic model. + # When we create the model, we pass information about the namespace + # where the model is being created, so the type annotations can + # be resolved correctly as well. + # self.__class__.__module__ handles the case when the Runnable is + # being sub-classed in a different module. + __module_name=self.__class__.__module__, ) @property diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 8b9ad729206..e6db922730f 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -21,6 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableBindingBase, Runnabl from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( ConfigurableFieldSpec, + Output, create_model, get_unique_config_specs, ) @@ -362,6 +363,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): history_factory_config=_config_specs, **kwargs, ) + self._history_chain = history_chain @property def config_specs(self) -> List[ConfigurableFieldSpec]: @@ -393,6 +395,39 @@ class RunnableWithMessageHistory(RunnableBindingBase): **fields, ) + @property + def OutputType(self) -> Type[Output]: + output_type = self._history_chain.OutputType + return output_type + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + """Get a pydantic model that can be used to validate output to the Runnable. + + Runnables that leverage the configurable_fields and configurable_alternatives + methods will have a dynamic output schema that depends on which + configuration the Runnable is invoked with. + + This method allows to get an output schema for a specific configuration. + + Args: + config: A config to use when generating the schema. + + Returns: + A pydantic model that can be used to validate output. + """ + root_type = self.OutputType + + if inspect.isclass(root_type) and issubclass(root_type, BaseModel): + return root_type + + return create_model( + "RunnableWithChatHistoryOutput", + __root__=root_type, + __module_name=self.__class__.__module__, + ) + def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: return False diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index e1316f0dd08..eb03d11346f 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -713,7 +713,10 @@ NO_DEFAULT = object() def _create_root_model( - name: str, type_: Any, default_: object = NO_DEFAULT + name: str, + type_: Any, + module_name: Optional[str] = None, + default_: object = NO_DEFAULT, ) -> Type[BaseModel]: """Create a base class.""" @@ -751,7 +754,7 @@ def _create_root_model( "model_config": ConfigDict(arbitrary_types_allowed=True), "schema": classmethod(schema), "model_json_schema": classmethod(model_json_schema), - "__module__": "langchain_core.runnables.utils", + "__module__": module_name or "langchain_core.runnables.utils", } if default_ is not NO_DEFAULT: @@ -770,18 +773,24 @@ def _create_root_model_cached( __model_name: str, type_: Any, default_: object = NO_DEFAULT, + module_name: Optional[str] = None, ) -> Type[BaseModel]: - return _create_root_model(__model_name, type_, default_) + return _create_root_model( + __model_name, type_, default_=default_, module_name=module_name + ) def create_model( __model_name: str, + __module_name: Optional[str] = None, **field_definitions: Any, ) -> Type[BaseModel]: """Create a pydantic model with the given field definitions. Args: __model_name: The name of the model. + __module_name: The name of the module where the model is defined. + This is used by Pydantic to resolve any forward references. **field_definitions: The field definitions for the model. Returns: @@ -803,10 +812,16 @@ def create_model( kwargs = {"type_": arg} try: - named_root_model = _create_root_model_cached(__model_name, **kwargs) + named_root_model = _create_root_model_cached( + __model_name, module_name=__module_name, **kwargs + ) except TypeError: # something in the arguments into _create_root_model_cached is not hashable - named_root_model = _create_root_model(__model_name, **kwargs) + named_root_model = _create_root_model( + __model_name, + module_name=__module_name, + **kwargs, + ) return named_root_model try: return _create_model_cached(__model_name, **field_definitions) diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 22ee2aa1795..377e7a72457 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -454,6 +454,41 @@ def test_get_input_schema_input_dict() -> None: ) +def test_get_output_schema() -> None: + """Test get output schema.""" + runnable = RunnableLambda( + lambda input: { + "output": [ + AIMessage( + content="you said: " + + "\n".join( + [ + str(m.content) + for m in input["history"] + if isinstance(m, HumanMessage) + ] + + [input["input"]] + ) + ) + ] + } + ) + get_session_history = _get_get_session_history() + with_history = RunnableWithMessageHistory( + runnable, + get_session_history, + input_messages_key="input", + history_messages_key="history", + output_messages_key="output", + ) + output_type = with_history.get_output_schema() + + assert _schema(output_type) == { + "title": "RunnableWithChatHistoryOutput", + "type": "object", + } + + def test_get_input_schema_input_messages() -> None: from pydantic import RootModel