mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-07 16:18:09 +00:00
Ensure config defaults apply even when a config is passed in
This commit is contained in:
parent
7fe474d198
commit
c1b1666ec8
@ -35,8 +35,11 @@ from langchain.schema.runnable.utils import (
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
|
||||
|
||||
def _empty_config() -> RunnableConfig:
|
||||
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||
if config is not None:
|
||||
empty.update(config)
|
||||
return empty
|
||||
|
||||
|
||||
def _get_callback_manager(config: Mapping) -> Any:
|
||||
@ -238,10 +241,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return (
|
||||
config
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
deepcopy(config) if config is not None else _empty_config()
|
||||
for _ in range(length)
|
||||
]
|
||||
else [deepcopy(_ensure_config(config)) for _ in range(length)]
|
||||
)
|
||||
|
||||
def _call_with_config(
|
||||
@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> Output:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
final_output: Optional[Output] = None
|
||||
final_output_supported = True
|
||||
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -724,7 +724,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -753,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -899,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -966,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_callbacks=None,
|
||||
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Dict[str, Any]:
|
||||
# setup callbacks
|
||||
config = config or _empty_config()
|
||||
config = _ensure_config(config)
|
||||
callback_manager = _get_async_callback_manager(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
|
@ -134,8 +134,12 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
|
||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(tags=["a-tag"])),
|
||||
mocker.call("wooorld", dict(tags=["a-tag"])),
|
||||
mocker.call(
|
||||
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||
),
|
||||
mocker.call(
|
||||
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||
),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
@ -156,8 +160,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
7,
|
||||
]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(metadata={"key": "value"})),
|
||||
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
||||
mocker.call(
|
||||
"hello",
|
||||
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||
),
|
||||
mocker.call(
|
||||
"wooorld",
|
||||
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user