From c1b1666ec850e465bf93bd01d34e09fc457076cc Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 18 Aug 2023 10:02:29 +0100 Subject: [PATCH] Ensure config defaults apply even when a config is passed in --- .../langchain/schema/runnable/base.py | 36 +++++++++---------- .../schema/runnable/test_runnable.py | 18 +++++++--- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index c91456394e2..0d9df2baeea 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -35,8 +35,11 @@ from langchain.schema.runnable.utils import ( from langchain.utils.aiter import atee, py_anext -def _empty_config() -> RunnableConfig: - return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) +def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: + empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + if config is not None: + empty.update(config) + return empty def _get_callback_manager(config: Mapping) -> Any: @@ -238,10 +241,7 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [ - deepcopy(config) if config is not None else _empty_config() - for _ in range(length) - ] + else [deepcopy(_ensure_config(config)) for _ in range(length)] ) def _call_with_config( @@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), @@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), @@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -724,7 +724,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -753,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -899,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( @@ -966,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( @@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: # setup callbacks - config = config or _empty_config() + config = _ensure_config(config) callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index c0cae4d9bdf..8bfecb1821d 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -134,8 +134,12 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert spy.call_args_list == [ - mocker.call("hello", dict(tags=["a-tag"])), - mocker.call("wooorld", dict(tags=["a-tag"])), + mocker.call( + "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), + mocker.call( + "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), ] spy.reset_mock() @@ -156,8 +160,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: 7, ] assert spy.call_args_list == [ - mocker.call("hello", dict(metadata={"key": "value"})), - mocker.call("wooorld", dict(metadata={"key": "value"})), + mocker.call( + "hello", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), + mocker.call( + "wooorld", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), ]