core[minor]: Add Runnable.batch_as_completed (#17603)

This PR adds `batch as completed` method to the standard Runnable
interface. It takes in a list of inputs and yields the corresponding
outputs as the inputs are completed.
This commit is contained in:
Nuno Campos
2024-03-13 11:18:02 -07:00
committed by GitHub
parent 71d0981f18
commit 2b7c3c548d
2 changed files with 302 additions and 0 deletions

View File

@@ -530,6 +530,76 @@ class Runnable(Generic[Input, Output], ABC):
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]:
...
@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
...
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs,
yielding results as they complete."""
if not inputs:
return
configs = get_config_list(config, len(inputs))
def invoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
except Exception as e:
out = e
else:
out = self.invoke(input, config, **kwargs)
return (i, out)
if len(inputs) == 1:
yield invoke(0, inputs[0], configs[0])
return
with get_executor_for_config(configs[0]) as executor:
futures = {
executor.submit(invoke, i, input, config)
for i, (input, config) in enumerate(zip(inputs, configs))
}
try:
while futures:
done, futures = wait(futures, return_when=FIRST_COMPLETED)
while done:
yield done.pop().result()
finally:
for future in futures:
future.cancel()
async def abatch(
self,
inputs: List[Input],
@@ -564,6 +634,64 @@ class Runnable(Generic[Input, Output], ABC):
coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]:
...
@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
...
async def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
"""Run ainvoke in parallel on a list of inputs,
yielding results as they complete."""
if not inputs:
return
configs = get_config_list(config, len(inputs))
async def ainvoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = await self.ainvoke(
input, config, **kwargs
)
except Exception as e:
out = e
else:
out = await self.ainvoke(input, config, **kwargs)
return (i, out)
coros = map(ainvoke, range(len(inputs)), inputs, configs)
for coro in asyncio.as_completed(coros):
yield await coro
def stream(
self,
input: Input,
@@ -4149,6 +4277,113 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
**{**self.kwargs, **kwargs},
)
@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]:
...
@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
...
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [self._merge_configs(config) for _ in range(len(inputs))]
# lol mypy
if return_exceptions:
yield from self.bound.batch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
else:
yield from self.bound.batch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]:
...
@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
...
async def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [self._merge_configs(config) for _ in range(len(inputs))]
if return_exceptions:
async for item in self.bound.abatch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
):
yield item
else:
async for item in self.bound.abatch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
):
yield item
def stream(
self,
input: Input,