diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 0e213478071..89500b79020 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2437,8 +2437,11 @@ class RunnableLambda(Runnable[Input, Output]): input: Input, run_manager: CallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: - output = call_func_with_variable_args(self.func, input, config, run_manager) + output = call_func_with_variable_args( + self.func, input, config, run_manager, **kwargs + ) # If the output is a runnable, invoke it if isinstance(output, Runnable): recursion_limit = config["recursion_limit"] @@ -2461,9 +2464,10 @@ class RunnableLambda(Runnable[Input, Output]): input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: output = await acall_func_with_variable_args( - self.afunc, input, config, run_manager + self.afunc, input, config, run_manager, **kwargs ) # If the output is a runnable, invoke it if isinstance(output, Runnable): @@ -2509,6 +2513,7 @@ class RunnableLambda(Runnable[Input, Output]): self._invoke, input, self._config(config, self.func), + **kwargs, ) else: raise TypeError( @@ -2528,6 +2533,7 @@ class RunnableLambda(Runnable[Input, Output]): self._ainvoke, input, self._config(config, self.afunc), + **kwargs, ) else: # Delegating to super implementation of ainvoke.