diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index a794ea6d5c0..cfe1f76aa07 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -68,9 +68,11 @@ class PutLocalVar(RunnablePassthrough): f"{(type(self.key))}." ) - def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other: + def invoke( + self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Other: self._concat_put(input, config=config, replace=True) - return super().invoke(input, config=config) + return super().invoke(input, config=config, **kwargs) async def ainvoke( self, @@ -79,7 +81,7 @@ class PutLocalVar(RunnablePassthrough): **kwargs: Optional[Any], ) -> Other: self._concat_put(input, config=config, replace=True) - return await super().ainvoke(input, config=config) + return await super().ainvoke(input, config=config, **kwargs) def transform( self, @@ -87,7 +89,7 @@ class PutLocalVar(RunnablePassthrough): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Other]: - for chunk in super().transform(input, config=config): + for chunk in super().transform(input, config=config, **kwargs): self._concat_put(chunk, config=config) yield chunk @@ -97,7 +99,7 @@ class PutLocalVar(RunnablePassthrough): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Other]: - async for chunk in super().atransform(input, config=config): + async for chunk in super().atransform(input, config=config, **kwargs): self._concat_put(chunk, config=config) yield chunk diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 803143df6ad..b9a3fff16e3 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -666,7 +666,7 @@ class Runnable(Generic[Input, Output], ABC): ) try: output = call_func_with_variable_args( - func, input, run_manager, config, **kwargs + func, input, config, run_manager, **kwargs ) except BaseException as e: run_manager.on_chain_error(e) @@ -702,7 +702,7 @@ class Runnable(Generic[Input, Output], ABC): ) try: output = await acall_func_with_variable_args( - func, input, run_manager, config, **kwargs + func, input, config, run_manager, **kwargs ) except BaseException as e: await run_manager.on_chain_error(e) @@ -2027,8 +2027,34 @@ class RunnableLambda(Runnable[Input, Output]): def __init__( self, - func: Union[Callable[[Input], Output], Callable[[Input], Awaitable[Output]]], - afunc: Optional[Callable[[Input], Awaitable[Output]]] = None, + func: Union[ + Union[ + Callable[[Input], Output], + Callable[[Input, RunnableConfig], Output], + Callable[[Input, CallbackManagerForChainRun], Output], + Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], + ], + Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, RunnableConfig], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ], + ], + afunc: Optional[ + Union[ + Callable[[Input], Awaitable[Output]], + Callable[[Input, RunnableConfig], Awaitable[Output]], + Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], + Callable[ + [Input, AsyncCallbackManagerForChainRun, RunnableConfig], + Awaitable[Output], + ], + ] + ] = None, ) -> None: """Create a RunnableLambda from a callable, and async callable or both. @@ -2136,7 +2162,7 @@ class RunnableLambda(Runnable[Input, Output]): run_manager: CallbackManagerForChainRun, config: RunnableConfig, ) -> Output: - output = call_func_with_variable_args(self.func, input, run_manager, config) + output = call_func_with_variable_args(self.func, input, config, run_manager) # If the output is a runnable, invoke it if isinstance(output, Runnable): recursion_limit = config["recursion_limit"] @@ -2161,7 +2187,7 @@ class RunnableLambda(Runnable[Input, Output]): config: RunnableConfig, ) -> Output: output = await acall_func_with_variable_args( - self.afunc, input, run_manager, config + self.afunc, input, config, run_manager ) # If the output is a runnable, invoke it if isinstance(output, Runnable): diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 1b720fb5b6e..408339fcfc4 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -185,18 +185,22 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: def call_func_with_variable_args( func: Union[ Callable[[Input], Output], + Callable[[Input, RunnableConfig], Output], Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], ], input: Input, - run_manager: CallbackManagerForChainRun, config: RunnableConfig, + run_manager: Optional[CallbackManagerForChainRun] = None, **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" if accepts_config(func): - kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) - if accepts_run_manager(func): + if run_manager is not None: + kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) + else: + kwargs["config"] = config + if run_manager is not None and accepts_run_manager(func): kwargs["run_manager"] = run_manager return func(input, **kwargs) # type: ignore[call-arg] @@ -204,6 +208,7 @@ def call_func_with_variable_args( async def acall_func_with_variable_args( func: Union[ Callable[[Input], Awaitable[Output]], + Callable[[Input, RunnableConfig], Awaitable[Output]], Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]], Callable[ [Input, AsyncCallbackManagerForChainRun, RunnableConfig], @@ -211,14 +216,17 @@ async def acall_func_with_variable_args( ], ], input: Input, - run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" if accepts_config(func): - kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) - if accepts_run_manager(func): + if run_manager is not None: + kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) + else: + kwargs["config"] = config + if run_manager is not None and accepts_run_manager(func): kwargs["run_manager"] = run_manager return await func(input, **kwargs) # type: ignore[call-arg] diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 6f369d753cd..117edb90740 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -27,7 +27,12 @@ from langchain.schema.runnable.base import ( RunnableParallel, RunnableSerializable, ) -from langchain.schema.runnable.config import RunnableConfig, get_executor_for_config +from langchain.schema.runnable.config import ( + RunnableConfig, + acall_func_with_variable_args, + call_func_with_variable_args, + get_executor_for_config, +) from langchain.schema.runnable.utils import AddableDict, ConfigurableFieldSpec from langchain.utils.aiter import atee, py_anext from langchain.utils.iter import safetee @@ -102,16 +107,34 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): input_type: Optional[Type[Other]] = None - func: Optional[Callable[[Other], None]] = None + func: Optional[ + Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]] + ] = None - afunc: Optional[Callable[[Other], Awaitable[None]]] = None + afunc: Optional[ + Union[ + Callable[[Other], Awaitable[None]], + Callable[[Other, RunnableConfig], Awaitable[None]], + ] + ] = None def __init__( self, func: Optional[ - Union[Callable[[Other], None], Callable[[Other], Awaitable[None]]] + Union[ + Union[Callable[[Other], None], Callable[[Other, RunnableConfig], None]], + Union[ + Callable[[Other], Awaitable[None]], + Callable[[Other, RunnableConfig], Awaitable[None]], + ], + ] + ] = None, + afunc: Optional[ + Union[ + Callable[[Other], Awaitable[None]], + Callable[[Other, RunnableConfig], Awaitable[None]], + ] ] = None, - afunc: Optional[Callable[[Other], Awaitable[None]]] = None, *, input_type: Optional[Type[Other]] = None, **kwargs: Any, @@ -161,9 +184,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): """ return RunnableAssign(RunnableParallel(kwargs)) - def invoke(self, input: Other, config: Optional[RunnableConfig] = None) -> Other: + def invoke( + self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Other: if self.func is not None: - self.func(input) + call_func_with_variable_args(self.func, input, config or {}, **kwargs) return self._call_with_config(identity, input, config) async def ainvoke( @@ -173,9 +198,11 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): **kwargs: Optional[Any], ) -> Other: if self.afunc is not None: - await self.afunc(input, **kwargs) + await acall_func_with_variable_args( + self.afunc, input, config or {}, **kwargs + ) elif self.func is not None: - self.func(input, **kwargs) + call_func_with_variable_args(self.func, input, config or {}, **kwargs) return await self._acall_with_config(aidentity, input, config) def transform( @@ -198,7 +225,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): final = final + chunk if final is not None: - self.func(final, **kwargs) + call_func_with_variable_args(self.func, final, config or {}, **kwargs) async def atransform( self, @@ -224,10 +251,13 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): final = final + chunk if final is not None: + config = config or {} if self.afunc is not None: - await self.afunc(final, **kwargs) + await acall_func_with_variable_args( + self.afunc, final, config, **kwargs + ) elif self.func is not None: - self.func(final, **kwargs) + call_func_with_variable_args(self.func, final, config, **kwargs) def stream( self,