mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
core[minor]: Add Runnable.batch_as_completed (#17603)
This PR adds `batch as completed` method to the standard Runnable interface. It takes in a list of inputs and yields the corresponding outputs as the inputs are completed.
This commit is contained in:
parent
71d0981f18
commit
2b7c3c548d
@ -530,6 +530,76 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
with get_executor_for_config(configs[0]) as executor:
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
|
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def batch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[False] = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[Tuple[int, Output]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def batch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[True],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def batch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
"""Run invoke in parallel on a list of inputs,
|
||||||
|
yielding results as they complete."""
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
i: int, input: Input, config: RunnableConfig
|
||||||
|
) -> Tuple[int, Union[Output, Exception]]:
|
||||||
|
if return_exceptions:
|
||||||
|
try:
|
||||||
|
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
out = e
|
||||||
|
else:
|
||||||
|
out = self.invoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
return (i, out)
|
||||||
|
|
||||||
|
if len(inputs) == 1:
|
||||||
|
yield invoke(0, inputs[0], configs[0])
|
||||||
|
return
|
||||||
|
|
||||||
|
with get_executor_for_config(configs[0]) as executor:
|
||||||
|
futures = {
|
||||||
|
executor.submit(invoke, i, input, config)
|
||||||
|
for i, (input, config) in enumerate(zip(inputs, configs))
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
while futures:
|
||||||
|
done, futures = wait(futures, return_when=FIRST_COMPLETED)
|
||||||
|
while done:
|
||||||
|
yield done.pop().result()
|
||||||
|
finally:
|
||||||
|
for future in futures:
|
||||||
|
future.cancel()
|
||||||
|
|
||||||
async def abatch(
|
async def abatch(
|
||||||
self,
|
self,
|
||||||
inputs: List[Input],
|
inputs: List[Input],
|
||||||
@ -564,6 +634,64 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
coros = map(ainvoke, inputs, configs)
|
coros = map(ainvoke, inputs, configs)
|
||||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def abatch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[False] = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Tuple[int, Output]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def abatch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[True],
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def abatch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
"""Run ainvoke in parallel on a list of inputs,
|
||||||
|
yielding results as they complete."""
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
i: int, input: Input, config: RunnableConfig
|
||||||
|
) -> Tuple[int, Union[Output, Exception]]:
|
||||||
|
if return_exceptions:
|
||||||
|
try:
|
||||||
|
out: Union[Output, Exception] = await self.ainvoke(
|
||||||
|
input, config, **kwargs
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
out = e
|
||||||
|
else:
|
||||||
|
out = await self.ainvoke(input, config, **kwargs)
|
||||||
|
|
||||||
|
return (i, out)
|
||||||
|
|
||||||
|
coros = map(ainvoke, range(len(inputs)), inputs, configs)
|
||||||
|
|
||||||
|
for coro in asyncio.as_completed(coros):
|
||||||
|
yield await coro
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
@ -4149,6 +4277,113 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
|||||||
**{**self.kwargs, **kwargs},
|
**{**self.kwargs, **kwargs},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def batch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[False] = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[Tuple[int, Output]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def batch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[True],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def batch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
if isinstance(config, list):
|
||||||
|
configs = cast(
|
||||||
|
List[RunnableConfig],
|
||||||
|
[self._merge_configs(conf) for conf in config],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configs = [self._merge_configs(config) for _ in range(len(inputs))]
|
||||||
|
# lol mypy
|
||||||
|
if return_exceptions:
|
||||||
|
yield from self.bound.batch_as_completed(
|
||||||
|
inputs,
|
||||||
|
configs,
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield from self.bound.batch_as_completed(
|
||||||
|
inputs,
|
||||||
|
configs,
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def abatch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[False] = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Tuple[int, Output]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def abatch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: Literal[True],
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def abatch_as_completed(
|
||||||
|
self,
|
||||||
|
inputs: List[Input],
|
||||||
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||||
|
*,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
|
||||||
|
if isinstance(config, list):
|
||||||
|
configs = cast(
|
||||||
|
List[RunnableConfig],
|
||||||
|
[self._merge_configs(conf) for conf in config],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
configs = [self._merge_configs(config) for _ in range(len(inputs))]
|
||||||
|
if return_exceptions:
|
||||||
|
async for item in self.bound.abatch_as_completed(
|
||||||
|
inputs,
|
||||||
|
configs,
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
):
|
||||||
|
yield item
|
||||||
|
else:
|
||||||
|
async for item in self.bound.abatch_as_completed(
|
||||||
|
inputs,
|
||||||
|
configs,
|
||||||
|
return_exceptions=return_exceptions,
|
||||||
|
**{**self.kwargs, **kwargs},
|
||||||
|
):
|
||||||
|
yield item
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: Input,
|
input: Input,
|
||||||
|
@ -1428,6 +1428,30 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
|
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
c
|
||||||
|
for c in fake.with_config(recursion_limit=5).batch_as_completed(
|
||||||
|
["hello", "wooorld"],
|
||||||
|
[dict(tags=["a-tag"]), dict(metadata={"key": "value"})],
|
||||||
|
)
|
||||||
|
) == [(0, 5), (1, 7)]
|
||||||
|
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
for i, call in enumerate(
|
||||||
|
sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)
|
||||||
|
):
|
||||||
|
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||||
|
if i == 0:
|
||||||
|
assert call.args[1].get("recursion_limit") == 5
|
||||||
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
|
assert call.args[1].get("metadata") == {}
|
||||||
|
else:
|
||||||
|
assert call.args[1].get("recursion_limit") == 5
|
||||||
|
assert call.args[1].get("tags") == []
|
||||||
|
assert call.args[1].get("metadata") == {"key": "value"}
|
||||||
|
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
assert fake.with_config(metadata={"a": "b"}).batch(
|
assert fake.with_config(metadata={"a": "b"}).batch(
|
||||||
["hello", "wooorld"], dict(tags=["a-tag"])
|
["hello", "wooorld"], dict(tags=["a-tag"])
|
||||||
) == [5, 7]
|
) == [5, 7]
|
||||||
@ -1438,6 +1462,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
assert call.args[1].get("metadata") == {"a": "b"}
|
assert call.args[1].get("metadata") == {"a": "b"}
|
||||||
spy.reset_mock()
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"]))
|
||||||
|
) == [(0, 5), (1, 7)]
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
for i, call in enumerate(spy.call_args_list):
|
||||||
|
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||||
|
assert call.args[1].get("tags") == ["a-tag"]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
handler = ConsoleCallbackHandler()
|
handler = ConsoleCallbackHandler()
|
||||||
assert (
|
assert (
|
||||||
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
||||||
@ -1484,6 +1517,40 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
spy.reset_mock()
|
||||||
|
|
||||||
|
assert sorted(
|
||||||
|
[
|
||||||
|
c
|
||||||
|
async for c in fake.with_config(
|
||||||
|
recursion_limit=5, tags=["c"]
|
||||||
|
).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"}))
|
||||||
|
]
|
||||||
|
) == [
|
||||||
|
(0, 5),
|
||||||
|
(1, 7),
|
||||||
|
]
|
||||||
|
assert len(spy.call_args_list) == 2
|
||||||
|
first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")
|
||||||
|
assert first_call == mocker.call(
|
||||||
|
"hello",
|
||||||
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=["c"],
|
||||||
|
callbacks=None,
|
||||||
|
recursion_limit=5,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
|
||||||
|
assert second_call == mocker.call(
|
||||||
|
"wooorld",
|
||||||
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=["c"],
|
||||||
|
callbacks=None,
|
||||||
|
recursion_limit=5,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user