Ensure config defaults apply even when a config is passed in

This commit is contained in:
Nuno Campos 2023-08-18 10:02:29 +01:00
parent 7fe474d198
commit c1b1666ec8
2 changed files with 32 additions and 22 deletions

View File

@ -35,8 +35,11 @@ from langchain.schema.runnable.utils import (
from langchain.utils.aiter import atee, py_anext
def _empty_config() -> RunnableConfig:
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
if config is not None:
empty.update(config)
return empty
def _get_callback_manager(config: Mapping) -> Any:
@ -238,10 +241,7 @@ class Runnable(Generic[Input, Output], ABC):
return (
config
if isinstance(config, list)
else [
deepcopy(config) if config is not None else _empty_config()
for _ in range(length)
]
else [deepcopy(_ensure_config(config)) for _ in range(length)]
)
def _call_with_config(
@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output: Optional[Output] = None
final_output_supported = True
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
run_manager = await callback_manager.on_chain_start(
dumpd(self),
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@ -724,7 +724,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@ -753,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@ -899,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_callback_manager(config)
# start the root run
run_manager = callback_manager.on_chain_start(
@ -966,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
# setup callbacks
config = config or _empty_config()
config = _ensure_config(config)
callback_manager = _get_async_callback_manager(config)
# start the root run
run_manager = await callback_manager.on_chain_start(

View File

@ -134,8 +134,12 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(tags=["a-tag"])),
mocker.call(
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
mocker.call(
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
),
]
spy.reset_mock()
@ -156,8 +160,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
7,
]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})),
mocker.call("wooorld", dict(metadata={"key": "value"})),
mocker.call(
"hello",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
mocker.call(
"wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
),
]