mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 17:08:47 +00:00
Add .with_config() method to Runnables which allows binding any config values to a Runnable
This commit is contained in:
parent
324c86acd5
commit
a3c69cf41d
@ -210,7 +210,20 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
Bind arguments to a Runnable, returning a new Runnable.
|
Bind arguments to a Runnable, returning a new Runnable.
|
||||||
"""
|
"""
|
||||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
return RunnableBinding(bound=self, kwargs=kwargs, config={})
|
||||||
|
|
||||||
|
def with_config(
|
||||||
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
"""
|
||||||
|
Bind config to a Runnable, returning a new Runnable.
|
||||||
|
"""
|
||||||
|
return RunnableBinding(
|
||||||
|
bound=self, config={**(config or {}), **kwargs}, kwargs={}
|
||||||
|
)
|
||||||
|
|
||||||
def map(self) -> Runnable[List[Input], List[Output]]:
|
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||||
"""
|
"""
|
||||||
@ -1479,6 +1492,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
kwargs: Mapping[str, Any]
|
kwargs: Mapping[str, Any]
|
||||||
|
|
||||||
|
config: Mapping[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@ -1491,7 +1506,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
return self.__class__.__module__.split(".")[:-1]
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
|
return self.__class__(
|
||||||
|
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_config(
|
||||||
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config={**self.config, **(config or {}), **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
@ -1499,7 +1528,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs})
|
return self.bound.invoke(
|
||||||
|
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
@ -1507,7 +1538,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Output:
|
) -> Output:
|
||||||
return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs})
|
return await self.bound.ainvoke(
|
||||||
|
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
def batch(
|
def batch(
|
||||||
self,
|
self,
|
||||||
@ -1515,7 +1548,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
|
configs = (
|
||||||
|
[{**self.config, **(conf or {})} for conf in config]
|
||||||
|
if isinstance(config, list)
|
||||||
|
else [
|
||||||
|
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
|
||||||
|
for _ in range(len(inputs))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
@ -1523,7 +1564,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
|
configs = (
|
||||||
|
[{**self.config, **(conf or {})} for conf in config]
|
||||||
|
if isinstance(config, list)
|
||||||
|
else [
|
||||||
|
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
|
||||||
|
for _ in range(len(inputs))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return await self.bound.abatch(
|
||||||
|
inputs,
|
||||||
|
[{**self.config, **(conf or {})} for conf in configs],
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
@ -1531,7 +1584,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs})
|
yield from self.bound.stream(
|
||||||
|
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
async def astream(
|
async def astream(
|
||||||
self,
|
self,
|
||||||
@ -1540,7 +1595,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.astream(
|
async for item in self.bound.astream(
|
||||||
input, config, **{**self.kwargs, **kwargs}
|
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
@ -1550,7 +1605,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[Output]:
|
) -> Iterator[Output]:
|
||||||
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
|
yield from self.bound.transform(
|
||||||
|
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||||
|
)
|
||||||
|
|
||||||
async def atransform(
|
async def atransform(
|
||||||
self,
|
self,
|
||||||
@ -1559,11 +1616,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[Output]:
|
) -> AsyncIterator[Output]:
|
||||||
async for item in self.bound.atransform(
|
async for item in self.bound.atransform(
|
||||||
input, config, **{**self.kwargs, **kwargs}
|
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||||
):
|
):
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||||
|
|
||||||
|
|
||||||
def coerce_to_runnable(
|
def coerce_to_runnable(
|
||||||
thing: Union[
|
thing: Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
|
@ -3,7 +3,8 @@ from __future__ import annotations
|
|||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||||
@ -48,7 +49,7 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||||
empty = RunnableConfig(
|
empty = RunnableConfig(
|
||||||
tags=[],
|
tags=[],
|
||||||
metadata={},
|
metadata={},
|
||||||
|
@ -2081,7 +2081,8 @@
|
|||||||
"stop": [
|
"stop": [
|
||||||
"Thought:"
|
"Thought:"
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
|
"config": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"llm": {
|
"llm": {
|
||||||
|
@ -112,6 +112,104 @@ class FakeRetriever(BaseRetriever):
|
|||||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_with_config(mocker: MockerFixture) -> None:
|
||||||
|
fake = FakeRunnable()
|
||||||
|
spy = mocker.spy(fake, "invoke")
|
||||||
|
|
||||||
|
assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"])),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert [
|
||||||
|
*fake.with_config(tags=["a-tag"]).stream(
|
||||||
|
"hello", dict(metadata={"key": "value"})
|
||||||
|
)
|
||||||
|
] == [5]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert fake.with_config(recursion_limit=5).batch(
|
||||||
|
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
||||||
|
) == [5, 7]
|
||||||
|
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy.call_args_list):
|
||||||
|
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||||
|
if i == 0:
|
||||||
|
assert call.args[1].get("recursion_limit") == 5
|
||||||
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[1].get("metadata") == {}
|
||||||
|
else:
|
||||||
|
assert call.args[1].get("recursion_limit") == 5
|
||||||
|
assert call.args[1].get("tags") == []
|
||||||
|
assert call.args[1].get("metadata") == {"key": "value"}
|
||||||
|
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert fake.with_config(metadata={"a": "b"}).batch(
|
||||||
|
["hello", "wooorld"], dict(tags=["a-tag"])
|
||||||
|
) == [5, 7]
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy.call_args_list):
|
||||||
|
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||||
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[1].get("metadata") == {"a": "b"}
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
||||||
|
"hello", config={"callbacks": []}
|
||||||
|
)
|
||||||
|
== 5
|
||||||
|
)
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert [
|
||||||
|
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
|
||||||
|
] == [5]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call("hello", dict(metadata={"a": "b"})),
|
||||||
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
|
||||||
|
["hello", "wooorld"], dict(metadata={"key": "value"})
|
||||||
|
) == [
|
||||||
|
5,
|
||||||
|
7,
|
||||||
|
]
|
||||||
|
assert spy.call_args_list == [
|
||||||
|
mocker.call(
|
||||||
|
"hello",
|
||||||
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=["c"],
|
||||||
|
callbacks=None,
|
||||||
|
_locals={},
|
||||||
|
recursion_limit=5,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
mocker.call(
|
||||||
|
"wooorld",
|
||||||
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=["c"],
|
||||||
|
callbacks=None,
|
||||||
|
_locals={},
|
||||||
|
recursion_limit=5,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||||
fake = FakeRunnable()
|
fake = FakeRunnable()
|
||||||
@ -1125,6 +1223,14 @@ async def test_map_astream_iterator_input() -> None:
|
|||||||
assert final_value.get("passthrough") == llm_res
|
assert final_value.get("passthrough") == llm_res
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_config_with_config() -> None:
|
||||||
|
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||||
|
|
||||||
|
assert dumpd(
|
||||||
|
llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
|
||||||
|
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
|
||||||
|
|
||||||
|
|
||||||
def test_bind_bind() -> None:
|
def test_bind_bind() -> None:
|
||||||
llm = FakeListLLM(responses=["i'm a textbot"])
|
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user