diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index bdbd7fc6999..98f7d758635 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -210,7 +210,20 @@ class Runnable(Generic[Input, Output], ABC): """ Bind arguments to a Runnable, returning a new Runnable. """ - return RunnableBinding(bound=self, kwargs=kwargs) + return RunnableBinding(bound=self, kwargs=kwargs, config={}) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + """ + Bind config to a Runnable, returning a new Runnable. + """ + return RunnableBinding( + bound=self, config={**(config or {}), **kwargs}, kwargs={} + ) def map(self) -> Runnable[List[Input], List[Output]]: """ @@ -1479,6 +1492,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): kwargs: Mapping[str, Any] + config: Mapping[str, Any] = Field(default_factory=dict) + class Config: arbitrary_types_allowed = True @@ -1491,7 +1506,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): return self.__class__.__module__.split(".")[:-1] def bind(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) + return self.__class__( + bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} + ) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config={**self.config, **(config or {}), **kwargs}, + ) def invoke( self, @@ -1499,7 +1528,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return self.bound.invoke(input, config, **{**self.kwargs, **kwargs}) + return self.bound.invoke( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) async def ainvoke( self, @@ -1507,7 +1538,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs}) + return await self.bound.ainvoke( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) def batch( self, @@ -1515,7 +1548,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs}) + configs = ( + [{**self.config, **(conf or {})} for conf in config] + if isinstance(config, list) + else [ + patch_config({**self.config, **(config or {})}, deep_copy_locals=True) + for _ in range(len(inputs)) + ] + ) + return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) async def abatch( self, @@ -1523,7 +1564,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs}) + configs = ( + [{**self.config, **(conf or {})} for conf in config] + if isinstance(config, list) + else [ + patch_config({**self.config, **(config or {})}, deep_copy_locals=True) + for _ in range(len(inputs)) + ] + ) + return await self.bound.abatch( + inputs, + [{**self.config, **(conf or {})} for conf in configs], + **{**self.kwargs, **kwargs}, + ) def stream( self, @@ -1531,7 +1584,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: - yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs}) + yield from self.bound.stream( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) async def astream( self, @@ -1540,7 +1595,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> AsyncIterator[Output]: async for item in self.bound.astream( - input, config, **{**self.kwargs, **kwargs} + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} ): yield item @@ -1550,7 +1605,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Iterator[Output]: - yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs}) + yield from self.bound.transform( + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} + ) async def atransform( self, @@ -1559,11 +1616,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Any, ) -> AsyncIterator[Output]: async for item in self.bound.atransform( - input, config, **{**self.kwargs, **kwargs} + input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs} ): yield item +RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig) + + def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index b97d904414d..10cf15c0d88 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,7 +3,8 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional +from typing_extensions import TypedDict if TYPE_CHECKING: from langchain.callbacks.base import BaseCallbackManager, Callbacks @@ -48,7 +49,7 @@ class RunnableConfig(TypedDict, total=False): """ -def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: +def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: empty = RunnableConfig( tags=[], metadata={}, diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index c48d4edbd41..fcb621fe8cf 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2081,7 +2081,8 @@ "stop": [ "Thought:" ] - } + }, + "config": {} } }, "llm": { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index f2447533107..32f488e10c9 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -112,6 +112,104 @@ class FakeRetriever(BaseRetriever): return [Document(page_content="foo"), Document(page_content="bar")] +@pytest.mark.asyncio +async def test_with_config(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") + + assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5 + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"])), + ] + spy.reset_mock() + + assert [ + *fake.with_config(tags=["a-tag"]).stream( + "hello", dict(metadata={"key": "value"}) + ) + ] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})), + ] + spy.reset_mock() + + assert fake.with_config(recursion_limit=5).batch( + ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] + ) == [5, 7] + + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + if i == 0: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {} + else: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == [] + assert call.args[1].get("metadata") == {"key": "value"} + + spy.reset_mock() + + assert fake.with_config(metadata={"a": "b"}).batch( + ["hello", "wooorld"], dict(tags=["a-tag"]) + ) == [5, 7] + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {"a": "b"} + spy.reset_mock() + + assert ( + await fake.with_config(metadata={"a": "b"}).ainvoke( + "hello", config={"callbacks": []} + ) + == 5 + ) + assert spy.call_args_list == [ + mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})), + ] + spy.reset_mock() + + assert [ + part async for part in fake.with_config(metadata={"a": "b"}).astream("hello") + ] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(metadata={"a": "b"})), + ] + spy.reset_mock() + + assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch( + ["hello", "wooorld"], dict(metadata={"key": "value"}) + ) == [ + 5, + 7, + ] + assert spy.call_args_list == [ + mocker.call( + "hello", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + _locals={}, + recursion_limit=5, + ), + ), + mocker.call( + "wooorld", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + _locals={}, + recursion_limit=5, + ), + ), + ] + + @pytest.mark.asyncio async def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() @@ -1125,6 +1223,14 @@ async def test_map_astream_iterator_input() -> None: assert final_value.get("passthrough") == llm_res +def test_with_config_with_config() -> None: + llm = FakeListLLM(responses=["i'm a textbot"]) + + assert dumpd( + llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"]) + ) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]})) + + def test_bind_bind() -> None: llm = FakeListLLM(responses=["i'm a textbot"])