diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 16f99324b0f..0dbabd1579d 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -1,4 +1,3 @@ -from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, @@ -12,8 +11,6 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable __all__ = [ - "GetLocalVar", - "PutLocalVar", "RouterInput", "RouterRunnable", "Runnable", diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py deleted file mode 100644 index 755a709fc95..00000000000 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations - -from typing import Any, AsyncIterator, Dict, Iterator, Mapping, Optional, Union - -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import Input, Output, Runnable -from langchain.schema.runnable.config import RunnableConfig -from langchain.schema.runnable.passthrough import RunnablePassthrough - - -class PutLocalVar(RunnablePassthrough): - key: Union[str, Mapping[str, str]] - """The key(s) to use for storing the input variable(s) in local state. - - If a string is provided then the entire input is stored under that key. If a - Mapping is provided, then the map values are gotten from the input and - stored in local state under the map keys. - """ - - def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None: - super().__init__(key=key, **kwargs) - - def _concat_put( - self, - input: Input, - *, - config: Optional[RunnableConfig] = None, - replace: bool = False, - ) -> None: - if config is None: - raise ValueError( - "PutLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - if isinstance(self.key, str): - if self.key not in config["_locals"] or replace: - config["_locals"][self.key] = input - else: - config["_locals"][self.key] += input - elif isinstance(self.key, Mapping): - if not isinstance(input, Mapping): - raise TypeError( - f"Received key of type Mapping but input of type {type(input)}. " - f"input is expected to be of type Mapping when key is Mapping." - ) - for input_key, put_key in self.key.items(): - if put_key not in config["_locals"] or replace: - config["_locals"][put_key] = input[input_key] - else: - config["_locals"][put_key] += input[input_key] - else: - raise TypeError( - f"`key` should be a string or Mapping[str, str], received type " - f"{(type(self.key))}." - ) - - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: - self._concat_put(input, config=config, replace=True) - return super().invoke(input, config=config) - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Input: - self._concat_put(input, config=config, replace=True) - return await super().ainvoke(input, config=config) - - def transform( - self, - input: Iterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Iterator[Input]: - for chunk in super().transform(input, config=config): - self._concat_put(chunk, config=config) - yield chunk - - async def atransform( - self, - input: AsyncIterator[Input], - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> AsyncIterator[Input]: - async for chunk in super().atransform(input, config=config): - self._concat_put(chunk, config=config) - yield chunk - - -class GetLocalVar( - Serializable, Runnable[Input, Union[Output, Dict[str, Union[Input, Output]]]] -): - key: str - """The key to extract from the local state.""" - passthrough_key: Optional[str] = None - """The key to use for passing through the invocation input. - - If None, then only the value retrieved from local state is returned. Otherwise a - dictionary ``{self.key: <>, self.passthrough_key: <>}`` - is returned. - """ - - def __init__(self, key: str, **kwargs: Any) -> None: - super().__init__(key=key, **kwargs) - - def _get( - self, - input: Input, - run_manager: Union[CallbackManagerForChainRun, Any], - config: RunnableConfig, - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - if self.passthrough_key: - return { - self.key: config["_locals"][self.key], - self.passthrough_key: input, - } - else: - return config["_locals"][self.key] - - async def _aget( - self, - input: Input, - run_manager: AsyncCallbackManagerForChainRun, - config: RunnableConfig, - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - return self._get(input, run_manager, config) - - def invoke( - self, input: Input, config: Optional[RunnableConfig] = None - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - if config is None: - raise ValueError( - "GetLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - - return self._call_with_config(self._get, input, config) - - async def ainvoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Optional[Any], - ) -> Union[Output, Dict[str, Union[Input, Output]]]: - if config is None: - raise ValueError( - "GetLocalVar should only be used in a RunnableSequence, and should " - "therefore always receive a non-null config." - ) - - return await self._acall_with_config(self._aget, input, config) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index aab395c46c8..87490cdd171 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -5,7 +5,6 @@ import copy import threading from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait -from copy import deepcopy from functools import partial from itertools import tee from typing import ( @@ -35,16 +34,11 @@ if TYPE_CHECKING: ) +from langchain.callbacks.base import BaseCallbackManager from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field -from langchain.schema.runnable.config import ( - RunnableConfig, - ensure_config, - get_async_callback_manager_for_config, - get_callback_manager_for_config, - patch_config, -) +from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.utils import ( accepts_run_manager, accepts_run_manager_and_config, @@ -244,9 +238,9 @@ class Runnable(Generic[Input, Output], ABC): ) return ( - list(map(ensure_config, config)) + config if isinstance(config, list) - else [deepcopy(ensure_config(config)) for _ in range(length)] + else [config.copy() if config is not None else {} for _ in range(length)] ) def _call_with_config( @@ -262,8 +256,14 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) + from langchain.callbacks.manager import CallbackManager + + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) run_manager = callback_manager.on_chain_start( dumpd(self), input, @@ -303,8 +303,14 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) + from langchain.callbacks.manager import AsyncCallbackManager + + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) run_manager = await callback_manager.on_chain_start( dumpd(self), input, @@ -352,6 +358,8 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. Use this to implement `stream()` or `transform()` in Runnable subclasses.""" + from langchain.callbacks.manager import CallbackManager + # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = tee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -360,8 +368,12 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) run_manager = callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -432,6 +444,8 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" + from langchain.callbacks.manager import AsyncCallbackManager + # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -440,8 +454,12 @@ class Runnable(Generic[Input, Output], ABC): final_output: Optional[Output] = None final_output_supported = True - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -517,9 +535,19 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): yield from self.fallbacks def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + from langchain.callbacks.manager import CallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = callback_manager.on_chain_start(dumpd(self), input) first_error = None @@ -549,9 +577,19 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -770,9 +808,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + from langchain.callbacks.manager import CallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = callback_manager.on_chain_start(dumpd(self), input) @@ -798,9 +846,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -935,9 +993,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: + from langchain.callbacks.manager import CallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) + config = config or {} + callback_manager = CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = callback_manager.on_chain_start(dumpd(self), input) @@ -1001,9 +1069,19 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> AsyncIterator[Output]: + from langchain.callbacks.manager import AsyncCallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -1115,7 +1193,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): from langchain.callbacks.manager import CallbackManager # setup callbacks - config = ensure_config(config) + config = config or {} callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), local_callbacks=None, @@ -1138,7 +1216,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): step.invoke, input, # mark each step as a child run - patch_config(deepcopy(config), run_manager.get_child()), + patch_config(config, run_manager.get_child()), ) for step in steps.values() ] @@ -1157,9 +1235,19 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Dict[str, Any]: + from langchain.callbacks.manager import AsyncCallbackManager + # setup callbacks - config = ensure_config(config) - callback_manager = get_async_callback_manager_for_config(config) + config = config or {} + callback_manager = AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + local_callbacks=None, + verbose=False, + inheritable_tags=config.get("tags"), + local_tags=None, + inheritable_metadata=config.get("metadata"), + local_metadata=None, + ) # start the root run run_manager = await callback_manager.on_chain_start(dumpd(self), input) @@ -1452,6 +1540,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): yield item +def patch_config( + config: RunnableConfig, callback_manager: BaseCallbackManager +) -> RunnableConfig: + config = config.copy() + config["callbacks"] = callback_manager + return config + + def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 00408b7ee6c..715b79fd9f6 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, List, TypedDict -from langchain.callbacks.base import BaseCallbackManager, Callbacks -from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager +from langchain.callbacks.base import Callbacks class RunnableConfig(TypedDict, total=False): @@ -26,42 +25,3 @@ class RunnableConfig(TypedDict, total=False): Callbacks for this call and any sub-calls (eg. a Chain calling an LLM). Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. """ - - _locals: Dict[str, Any] - """ - Local variables - """ - - -def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: - empty = RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) - if config is not None: - empty.update(config) - return empty - - -def patch_config( - config: RunnableConfig, - callbacks: BaseCallbackManager, -) -> RunnableConfig: - config = config.copy() - config["callbacks"] = callbacks - return config - - -def get_callback_manager_for_config(config: RunnableConfig) -> CallbackManager: - return CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) - - -def get_async_callback_manager_for_config( - config: RunnableConfig, -) -> AsyncCallbackManager: - return AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index d5d7c152c15..420b13fe802 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -47,11 +47,10 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]): ) -> Iterator[Input]: return self._transform_stream_with_config(input, identity, config) - async def atransform( + def atransform( self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> AsyncIterator[Input]: - async for chunk in self._atransform_stream_with_config(input, identity, config): - yield chunk + return self._atransform_stream_with_config(input, identity, config) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr similarity index 99% rename from libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr rename to libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr index 11e554a9cc4..321edb7a294 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/__snapshots__/test_runnable.ambr @@ -1352,7 +1352,6 @@ "lc": 1, "type": "not_implemented", "id": [ - "runnable", "test_runnable", "FakeRetriever" ] diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__init__.py b/libs/langchain/tests/unit_tests/schema/runnable/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py deleted file mode 100644 index ee07c0cfc6e..00000000000 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ /dev/null @@ -1,93 +0,0 @@ -from typing import Any, Callable, Type - -import pytest - -from langchain import PromptTemplate -from langchain.llms import FakeListLLM -from langchain.schema.runnable import ( - GetLocalVar, - PutLocalVar, - RunnablePassthrough, - RunnableSequence, -) - - -@pytest.mark.parametrize( - ("method", "input", "output"), - [ - (lambda r, x: r.invoke(x), "foo", "foo"), - (lambda r, x: r.batch(x), ["foo", "bar"], ["foo", "bar"]), - (lambda r, x: list(r.stream(x))[0], "foo", "foo"), - ], -) -def test_put_get(method: Callable, input: Any, output: Any) -> None: - runnable = PutLocalVar("input") | GetLocalVar("input") - assert method(runnable, input) == output - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("method", "input", "output"), - [ - (lambda r, x: r.ainvoke(x), "foo", "foo"), - (lambda r, x: r.abatch(x), ["foo", "bar"], ["foo", "bar"]), - ], -) -async def test_put_get_async(method: Callable, input: Any, output: Any) -> None: - runnable = PutLocalVar("input") | GetLocalVar("input") - assert await method(runnable, input) == output - - -@pytest.mark.parametrize( - ("runnable", "error"), - [ - (PutLocalVar("input"), ValueError), - (GetLocalVar("input"), ValueError), - (PutLocalVar("input") | GetLocalVar("missing"), KeyError), - ], -) -def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> None: - with pytest.raises(error): - runnable.invoke("foo") - - -def test_get_in_map() -> None: - runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")} - assert runnable.invoke("foo") == {"bar": "foo"} - - -def test_put_in_map() -> None: - runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") - with pytest.raises(KeyError): - runnable.invoke("foo") - - -@pytest.mark.parametrize( - "runnable", - [ - PutLocalVar("input") | GetLocalVar("input", passthrough_key="output"), - ( - PutLocalVar("input") - | {"input": RunnablePassthrough()} - | PromptTemplate.from_template("say {input}") - | FakeListLLM(responses=["hello"]) - | GetLocalVar("input", passthrough_key="output") - ), - ], -) -@pytest.mark.parametrize( - ("method", "input", "output"), - [ - (lambda r, x: r.invoke(x), "hello", {"input": "hello", "output": "hello"}), - (lambda r, x: r.batch(x), ["hello"], [{"input": "hello", "output": "hello"}]), - ( - lambda r, x: list(r.stream(x))[0], - "hello", - {"input": "hello", "output": "hello"}, - ), - ], -) -def test_put_get_sequence( - runnable: RunnableSequence, method: Callable, input: Any, output: Any -) -> None: - assert method(runnable, input) == output diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/test_runnable.py similarity index 97% rename from libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py rename to libs/langchain/tests/unit_tests/schema/test_runnable.py index 80d63c69123..352ddea4011 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/test_runnable.py @@ -132,24 +132,15 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] ) == [5, 7] assert spy.call_args_list == [ - mocker.call( - "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) - ), - mocker.call( - "wooorld", - dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), - ), + mocker.call("hello", dict(tags=["a-tag"])), + mocker.call("wooorld", dict(metadata={"key": "value"})), ] spy.reset_mock() assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7] assert spy.call_args_list == [ - mocker.call( - "hello", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) - ), - mocker.call( - "wooorld", dict(tags=["a-tag"], metadata={}, callbacks=None, _locals={}) - ), + mocker.call("hello", dict(tags=["a-tag"])), + mocker.call("wooorld", dict(tags=["a-tag"])), ] spy.reset_mock() @@ -170,14 +161,8 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: 7, ] assert spy.call_args_list == [ - mocker.call( - "hello", - dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), - ), - mocker.call( - "wooorld", - dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), - ), + mocker.call("hello", dict(metadata={"key": "value"})), + mocker.call("wooorld", dict(metadata={"key": "value"})), ]