Add .with_config() method to Runnables which allows binding any config values to a Runnable

This commit is contained in:
Nuno Campos 2023-08-24 11:53:29 +02:00
parent 324c86acd5
commit a3c69cf41d
4 changed files with 181 additions and 13 deletions

View File

@ -210,7 +210,20 @@ class Runnable(Generic[Input, Output], ABC):
""" """
Bind arguments to a Runnable, returning a new Runnable. Bind arguments to a Runnable, returning a new Runnable.
""" """
return RunnableBinding(bound=self, kwargs=kwargs) return RunnableBinding(bound=self, kwargs=kwargs, config={})
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
"""
Bind config to a Runnable, returning a new Runnable.
"""
return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={}
)
def map(self) -> Runnable[List[Input], List[Output]]: def map(self) -> Runnable[List[Input], List[Output]]:
""" """
@ -1479,6 +1492,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
kwargs: Mapping[str, Any] kwargs: Mapping[str, Any]
config: Mapping[str, Any] = Field(default_factory=dict)
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -1491,7 +1506,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
return self.__class__.__module__.split(".")[:-1] return self.__class__.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
)
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs},
)
def invoke( def invoke(
self, self,
@ -1499,7 +1528,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs}) return self.bound.invoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
async def ainvoke( async def ainvoke(
self, self,
@ -1507,7 +1538,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs}) return await self.bound.ainvoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
def batch( def batch(
self, self,
@ -1515,7 +1548,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs}) configs = (
[{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list)
else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
for _ in range(len(inputs))
]
)
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
async def abatch( async def abatch(
self, self,
@ -1523,7 +1564,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs}) configs = (
[{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list)
else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
for _ in range(len(inputs))
]
)
return await self.bound.abatch(
inputs,
[{**self.config, **(conf or {})} for conf in configs],
**{**self.kwargs, **kwargs},
)
def stream( def stream(
self, self,
@ -1531,7 +1584,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs}) yield from self.bound.stream(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
async def astream( async def astream(
self, self,
@ -1540,7 +1595,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.astream( async for item in self.bound.astream(
input, config, **{**self.kwargs, **kwargs} input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
): ):
yield item yield item
@ -1550,7 +1605,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Output]: ) -> Iterator[Output]:
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs}) yield from self.bound.transform(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
async def atransform( async def atransform(
self, self,
@ -1559,11 +1616,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
async for item in self.bound.atransform( async for item in self.bound.atransform(
input, config, **{**self.kwargs, **kwargs} input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
): ):
yield item yield item
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
def coerce_to_runnable( def coerce_to_runnable(
thing: Union[ thing: Union[
Runnable[Input, Output], Runnable[Input, Output],

View File

@ -3,7 +3,8 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
from typing_extensions import TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackManager, Callbacks from langchain.callbacks.base import BaseCallbackManager, Callbacks
@ -48,7 +49,7 @@ class RunnableConfig(TypedDict, total=False):
""" """
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
empty = RunnableConfig( empty = RunnableConfig(
tags=[], tags=[],
metadata={}, metadata={},

View File

@ -2081,7 +2081,8 @@
"stop": [ "stop": [
"Thought:" "Thought:"
] ]
} },
"config": {}
} }
}, },
"llm": { "llm": {

View File

@ -112,6 +112,104 @@ class FakeRetriever(BaseRetriever):
return [Document(page_content="foo"), Document(page_content="bar")] return [Document(page_content="foo"), Document(page_content="bar")]
@pytest.mark.asyncio
async def test_with_config(mocker: MockerFixture) -> None:
fake = FakeRunnable()
spy = mocker.spy(fake, "invoke")
assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
]
spy.reset_mock()
assert [
*fake.with_config(tags=["a-tag"]).stream(
"hello", dict(metadata={"key": "value"})
)
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.with_config(recursion_limit=5).batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
spy.reset_mock()
assert fake.with_config(metadata={"a": "b"}).batch(
["hello", "wooorld"], dict(tags=["a-tag"])
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock()
assert (
await fake.with_config(metadata={"a": "b"}).ainvoke(
"hello", config={"callbacks": []}
)
== 5
)
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
]
spy.reset_mock()
assert [
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"a": "b"})),
]
spy.reset_mock()
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
["hello", "wooorld"], dict(metadata={"key": "value"})
) == [
5,
7,
]
assert spy.call_args_list == [
mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
recursion_limit=5,
),
),
mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
recursion_limit=5,
),
),
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None: async def test_default_method_implementations(mocker: MockerFixture) -> None:
fake = FakeRunnable() fake = FakeRunnable()
@ -1125,6 +1223,14 @@ async def test_map_astream_iterator_input() -> None:
assert final_value.get("passthrough") == llm_res assert final_value.get("passthrough") == llm_res
def test_with_config_with_config() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])
assert dumpd(
llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
def test_bind_bind() -> None: def test_bind_bind() -> None:
llm = FakeListLLM(responses=["i'm a textbot"]) llm = FakeListLLM(responses=["i'm a textbot"])