mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-08 16:48:49 +00:00
Ensure config defaults apply even when a config is passed in
This commit is contained in:
parent
7fe474d198
commit
c1b1666ec8
@ -35,8 +35,11 @@ from langchain.schema.runnable.utils import (
|
|||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
|
|
||||||
|
|
||||||
def _empty_config() -> RunnableConfig:
|
def _ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={})
|
||||||
|
if config is not None:
|
||||||
|
empty.update(config)
|
||||||
|
return empty
|
||||||
|
|
||||||
|
|
||||||
def _get_callback_manager(config: Mapping) -> Any:
|
def _get_callback_manager(config: Mapping) -> Any:
|
||||||
@ -238,10 +241,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
return (
|
return (
|
||||||
config
|
config
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [
|
else [deepcopy(_ensure_config(config)) for _ in range(length)]
|
||||||
deepcopy(config) if config is not None else _empty_config()
|
|
||||||
for _ in range(length)
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_with_config(
|
def _call_with_config(
|
||||||
@ -253,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -283,7 +283,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> Output:
|
) -> Output:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
with callbacks. Use this method to implement ainvoke() in subclasses."""
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -322,7 +322,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -387,7 +387,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
final_output: Optional[Output] = None
|
final_output: Optional[Output] = None
|
||||||
final_output_supported = True
|
final_output_supported = True
|
||||||
|
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
dumpd(self),
|
dumpd(self),
|
||||||
@ -462,7 +462,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -495,7 +495,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
@ -724,7 +724,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -753,7 +753,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Output:
|
) -> Output:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
@ -899,7 +899,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_callback_manager(config)
|
callback_manager = _get_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -966,7 +966,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
@ -1068,7 +1068,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
inheritable_callbacks=config.get("callbacks"),
|
inheritable_callbacks=config.get("callbacks"),
|
||||||
local_callbacks=None,
|
local_callbacks=None,
|
||||||
@ -1108,7 +1108,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
config = config or _empty_config()
|
config = _ensure_config(config)
|
||||||
callback_manager = _get_async_callback_manager(config)
|
callback_manager = _get_async_callback_manager(config)
|
||||||
# start the root run
|
# start the root run
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
|
@ -134,8 +134,12 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
|
||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call("hello", dict(tags=["a-tag"])),
|
mocker.call(
|
||||||
mocker.call("wooorld", dict(tags=["a-tag"])),
|
"hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={})
|
||||||
|
),
|
||||||
]
|
]
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
@ -156,8 +160,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
7,
|
7,
|
||||||
]
|
]
|
||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call("hello", dict(metadata={"key": "value"})),
|
mocker.call(
|
||||||
mocker.call("wooorld", dict(metadata={"key": "value"})),
|
"hello",
|
||||||
|
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld",
|
||||||
|
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user