diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 9968595677f..803143df6ad 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2425,7 +2425,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( - bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} + bound=self.bound, + config=self.config, + kwargs={**self.kwargs, **kwargs}, + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, ) def with_config( @@ -2438,6 +2442,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]): bound=self.bound, kwargs=self.kwargs, config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), + custom_input_type=self.custom_input_type, + custom_output_type=self.custom_output_type, ) def with_types(