diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 0dbabd1579d..16f99324b0f 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -1,3 +1,4 @@ +from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, @@ -11,6 +12,8 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable __all__ = [ + "GetLocalVar", + "PutLocalVar", "RouterInput", "RouterRunnable", "Runnable", diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py new file mode 100644 index 00000000000..755a709fc95 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) +from langchain.load.serializable import Serializable +from langchain.schema.runnable.base import Input, Output, Runnable +from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.passthrough import RunnablePassthrough + + +class PutLocalVar(RunnablePassthrough): + key: Union[str, Mapping[str, str]] + """The key(s) to use for storing the input variable(s) in local state. + + If a string is provided then the entire input is stored under that key. If a + Mapping is provided, then the map values are gotten from the input and + stored in local state under the map keys. + """ + + def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def _concat_put( + self, + input: Input, + *, + config: Optional[RunnableConfig] = None, + replace: bool = False, + ) -> None: + if config is None: + raise ValueError( + "PutLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + if isinstance(self.key, str): + if self.key not in config["_locals"] or replace: + config["_locals"][self.key] = input + else: + config["_locals"][self.key] += input + elif isinstance(self.key, Mapping): + if not isinstance(input, Mapping): + raise TypeError( + f"Received key of type Mapping but input of type {type(input)}. " + f"input is expected to be of type Mapping when key is Mapping." + ) + for input_key, put_key in self.key.items(): + if put_key not in config["_locals"] or replace: + config["_locals"][put_key] = input[input_key] + else: + config["_locals"][put_key] += input[input_key] + else: + raise TypeError( + f"`key` should be a string or Mapping[str, str], received type " + f"{(type(self.key))}." + ) + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: + self._concat_put(input, config=config, replace=True) + return super().invoke(input, config=config) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Input: + self._concat_put(input, config=config, replace=True) + return await super().ainvoke(input, config=config) + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Input]: + for chunk in super().transform(input, config=config): + self._concat_put(chunk, config=config) + yield chunk + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Input]: + async for chunk in super().atransform(input, config=config): + self._concat_put(chunk, config=config) + yield chunk + + +class GetLocalVar( + Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] +): + key: str + """The key to extract from the local state.""" + passthrough_key: Optional[str] = None + """The key to use for passing through the invocation input. + + If None, then only the value retrieved from local state is returned. Otherwise a + dictionary ``{self.key: <>, self.passthrough_key: <>}`` + is returned. + """ + + def __init__(self, key: str, **kwargs: Any) -> None: + super().__init__(key=key, **kwargs) + + def _get( + self, + input: Input, + run_manager: Union[CallbackManagerForChainRun, Any], + config: RunnableConfig, + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if self.passthrough_key: + return { + self.key: config["_locals"][self.key], + self.passthrough_key: input, + } + else: + return config["_locals"][self.key] + + async def _aget( + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + return self._get(input, run_manager, config) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if config is None: + raise ValueError( + "GetLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + + return self._call_with_config(self._get, input, config) + + async def ainvoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Union[Output, Dict[str, Union[Input, Output]]]: + if config is None: + raise ValueError( + "GetLocalVar should only be used in a RunnableSequence, and should " + "therefore always receive a non-null config." + ) + + return await self._acall_with_config(self._aget, input, config) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 87490cdd171..aab395c46c8 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -5,6 +5,7 @@ import copy import threading from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait +from copy import deepcopy from functools import partial from itertools import tee from typing import ( @@ -34,11 +35,16 @@ if TYPE_CHECKING: ) -from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field -from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.config import ( + RunnableConfig, + ensure_config, + get_async_callback_manager_for_config, + get_callback_manager_for_config, + patch_config, +) from langchain.schema.runnable.utils import ( accepts_run_manager, accepts_run_manager_and_config, @@ -238,9 +244,9 @@ class Runnable(Generic[Input, Output], ABC): ) return ( - config + list(map(ensure_config, config)) if isinstance(config, list) - else [config.copy() if config is not None else {} for _ in range(length)] + else [deepcopy(ensure_config(config)) for _ in range(length)] ) def _call_with_config( @@ -256,14 +262,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - from langchain.callbacks.manager import CallbackManager - - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( dumpd(self), input, @@ -303,14 +303,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - from langchain.callbacks.manager import AsyncCallbackManager - - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( dumpd(self), input, @@ -358,8 +352,6 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. Use this to implement `stream()` or `transform()` in Runnable subclasses.""" - from langchain.callbacks.manager import CallbackManager - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = tee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -368,12 +360,8 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) run_manager = callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -444,8 +432,6 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" - from langchain.callbacks.manager import AsyncCallbackManager - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -454,12 +440,8 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -535,19 +517,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): yield from self.fallbacks def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - from langchain.callbacks.manager import CallbackManager - # setup callbacks - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start(dumpd(self), input) first_error = None @@ -577,19 +549,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -808,19 +770,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - from langchain.callbacks.manager import CallbackManager - # setup callbacks - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start(dumpd(self), input) @@ -846,19 +798,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -993,19 +935,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: - from langchain.callbacks.manager import CallbackManager - # setup callbacks - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) # start the root run run_manager = callback_manager.on_chain_start(dumpd(self), input) @@ -1069,19 +1001,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -1193,7 +1115,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = config or {} + config = ensure_config(config) callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1216,7 +1138,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): step.invoke, input, # mark each step as a child run - patch_config(config, run_manager.get_child()), + patch_config(deepcopy(config), run_manager.get_child()), ) for step in steps.values() ] @@ -1235,19 +1157,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Dict[str, Any]: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -1540,14 +1452,6 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): yield item -def patch_config( - config: RunnableConfig, callback_manager: BaseCallbackManager -) -> RunnableConfig: - config = config.copy() - config["callbacks"] = callback_manager - return config - - def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 715b79fd9f6..00408b7ee6c 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List, Optional, TypedDict -from langchain.callbacks.base import Callbacks +from langchain.callbacks.base import BaseCallbackManager, Callbacks +from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager class RunnableConfig(TypedDict, total=False): @@ -25,3 +26,42 @@ class RunnableConfig(TypedDict, total=False): Callbacks for this call and any sub-calls (eg. a Chain calling an LLM). Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. """ + + _locals: Dict[str, Any] + """ + Local variables + """ + + +def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: + empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + if config is not None: + empty.update(config) + return empty + + +def patch_config( + config: RunnableConfig, + callbacks: BaseCallbackManager, +) -> RunnableConfig: + config = config.copy() + config["callbacks"] = callbacks + return config + + +def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: + return CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +def get_async_callback_manager_for_config( + config: RunnableConfig, +) -> AsyncCallbackManager: + return AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 420b13fe802..d5d7c152c15 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -47,10 +47,11 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): ) -> Iterator[Input]: return self._transform_stream_with_config(input, identity, config) - def atransform( + async def atransform( self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> AsyncIterator[Input]: - return self._atransform_stream_with_config(input, identity, config) + async for chunk in self._atransform_stream_with_config(input, identity, config): + yield chunk diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__init__.py b/libs/langchain/tests/unit_tests/schema/runnable/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr similarity index 99% rename from libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr rename to libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 321edb7a294..11e554a9cc4 100644 --- a/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -1352,6 +1352,7 @@ "lc": 1, "type": "not_implemented", "id": [ + "runnable", "test_runnable", "FakeRetriever" ] diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py new file mode 100644 index 00000000000..ee07c0cfc6e --- /dev/null +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -0,0 +1,93 @@ +from typing import Any, Callable, Type + +import pytest + +from langchain import PromptTemplate +from langchain.llms import FakeListLLM +from langchain.schema.runnable import ( + GetLocalVar, + PutLocalVar, + RunnablePassthrough, + RunnableSequence, +) + + +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.invoke(x), "foo", "foo"), + (lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]), + (lambda r, x: list(r.stream(x))[0], "foo", "foo"), + ], +) +def test_put_get(method: Callable, input: Any, output: Any) -> None: + runnable = PutLocalVar("input") | GetLocalVar("input") + assert method(runnable, input) == output + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.ainvoke(x), "foo", "foo"), + (lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]), + ], +) +async def test_put_get_async(method: Callable, input: Any, output: Any) -> None: + runnable = PutLocalVar("input") | GetLocalVar("input") + assert await method(runnable, input) == output + + +@pytest.mark.parametrize( + ("runnable", "error"), + [ + (PutLocalVar("input"), ValueError), + (GetLocalVar("input"), ValueError), + (PutLocalVar("input") | GetLocalVar("missing"), KeyError), + ], +) +def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None: + with pytest.raises(error): + runnable.invoke("foo") + + +def test_get_in_map() -> None: + runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")} + assert runnable.invoke("foo") == {"bar": "foo"} + + +def test_put_in_map() -> None: + runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") + with pytest.raises(KeyError): + runnable.invoke("foo") + + +@pytest.mark.parametrize( + "runnable", + [ + PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"), + ( + PutLocalVar("input") + | {"input": RunnablePassthrough()} + | PromptTemplate.from_template("say {input}") + | FakeListLLM(responses=["hello"]) + | GetLocalVar("input", passthrough_key="output") + ), + ], +) +@pytest.mark.parametrize( + ("method", "input", "output"), + [ + (lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}), + (lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]), + ( + lambda r, x: list(r.stream(x))[0], + "hello", + {"input": "hello", "output": "hello"}, + ), + ], +) +def test_put_get_sequence( + runnable: RunnableSequence, method: Callable, input: Any, output: Any +) -> None: + assert method(runnable, input) == output diff --git a/libs/langchain/tests/unit_tests/schema/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py similarity index 97% rename from libs/langchain/tests/unit_tests/schema/test_runnable.py rename to libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 352ddea4011..80d63c69123 100644 --- a/libs/langchain/tests/unit_tests/schema/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -132,15 +132,24 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] ) == [5, 7] assert spy.call_args_list == [ - mocker.call("hello", dict(tags=["a-tag"])), - mocker.call("wooorld", dict(metadata={"key": "value"})), + mocker.call( + "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), + mocker.call( + "wooorld", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), ] spy.reset_mock() assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert spy.call_args_list == [ - mocker.call("hello", dict(tags=["a-tag"])), - mocker.call("wooorld", dict(tags=["a-tag"])), + mocker.call( + "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), + mocker.call( + "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) + ), ] spy.reset_mock() @@ -161,8 +170,14 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: 7, ] assert spy.call_args_list == [ - mocker.call("hello", dict(metadata={"key": "value"})), - mocker.call("wooorld", dict(metadata={"key": "value"})), + mocker.call( + "hello", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), + mocker.call( + "wooorld", + dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), + ), ]