Ensure config defaults apply even when a config is passed in

This commit is contained in:
Nuno Campos 2023-08-18 10:02:29 +01:00
parent 7fe474d198
commit c1b1666ec8
2 changed files with 32 additions and 22 deletions

View File

@ -35,8 +35,11 @@ from langchain.schema.runnable.utils import (
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
def _empty_config() -> RunnableConfig: def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def _get_callback_manager(config: Mapping) -> Any: def _get_callback_manager(config: Mapping) -> Any:
@ -238,10 +241,7 @@ class Runnable(Generic[Input, Output], ABC):
return ( return (
config config
if isinstance(config, list) if isinstance(config, list)
else [ else [deepcopy(_ensure_config(config)) for _ in range(length)]
deepcopy(config) if config is not None else _empty_config()
for _ in range(length)
]
) )
def _call_with_config( def _call_with_config(
@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses."""
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None final_output: Optional[Output] = None
final_output_supported = True final_output_supported = True
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None final_output: Optional[Output] = None
final_output_supported = True final_output_supported = True
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
dumpd(self), dumpd(self),
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -724,7 +724,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -753,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -899,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]: ) -> Iterator[Output]:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_callback_manager(config) callback_manager = _get_callback_manager(config)
# start the root run # start the root run
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -966,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"), inheritable_callbacks=config.get("callbacks"),
local_callbacks=None, local_callbacks=None,
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# setup callbacks # setup callbacks
config = config or _empty_config() config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config) callback_manager = _get_async_callback_manager(config)
# start the root run # start the root run
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(

View File

@ -134,8 +134,12 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [ assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])), mocker.call(
mocker.call("wooorld", dict(tags=["a-tag"])), "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
] ]
spy.reset_mock() spy.reset_mock()
@ -156,8 +160,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
7, 7,
] ]
assert spy.call_args_list == [ assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})), mocker.call(
mocker.call("wooorld", dict(metadata={"key": "value"})), "hello",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
] ]