From 2b7c3c548dbe1c4026246deff88e7d0590df71ba Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 13 Mar 2024 11:18:02 -0700 Subject: [PATCH] core[minor]: Add Runnable.batch_as_completed (#17603) This PR adds `batch as completed` method to the standard Runnable interface. It takes in a list of inputs and yields the corresponding outputs as the inputs are completed. --- libs/core/langchain_core/runnables/base.py | 235 ++++++++++++++++++ .../unit_tests/runnables/test_runnable.py | 67 +++++ 2 files changed, 302 insertions(+) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 8ad86b047d6..8123aba94b2 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -530,6 +530,76 @@ class Runnable(Generic[Input, Output], ABC): with get_executor_for_config(configs[0]) as executor: return cast(List[Output], list(executor.map(invoke, inputs, configs))) + @overload + def batch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[False] = False, + **kwargs: Any, + ) -> Iterator[Tuple[int, Output]]: + ... + + @overload + def batch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[True], + **kwargs: Any, + ) -> Iterator[Tuple[int, Union[Output, Exception]]]: + ... + + def batch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> Iterator[Tuple[int, Union[Output, Exception]]]: + """Run invoke in parallel on a list of inputs, + yielding results as they complete.""" + + if not inputs: + return + + configs = get_config_list(config, len(inputs)) + + def invoke( + i: int, input: Input, config: RunnableConfig + ) -> Tuple[int, Union[Output, Exception]]: + if return_exceptions: + try: + out: Union[Output, Exception] = self.invoke(input, config, **kwargs) + except Exception as e: + out = e + else: + out = self.invoke(input, config, **kwargs) + + return (i, out) + + if len(inputs) == 1: + yield invoke(0, inputs[0], configs[0]) + return + + with get_executor_for_config(configs[0]) as executor: + futures = { + executor.submit(invoke, i, input, config) + for i, (input, config) in enumerate(zip(inputs, configs)) + } + + try: + while futures: + done, futures = wait(futures, return_when=FIRST_COMPLETED) + while done: + yield done.pop().result() + finally: + for future in futures: + future.cancel() + async def abatch( self, inputs: List[Input], @@ -564,6 +634,64 @@ class Runnable(Generic[Input, Output], ABC): coros = map(ainvoke, inputs, configs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) + @overload + def abatch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[False] = False, + **kwargs: Optional[Any], + ) -> AsyncIterator[Tuple[int, Output]]: + ... + + @overload + def abatch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[True], + **kwargs: Optional[Any], + ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: + ... + + async def abatch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: + """Run ainvoke in parallel on a list of inputs, + yielding results as they complete.""" + + if not inputs: + return + + configs = get_config_list(config, len(inputs)) + + async def ainvoke( + i: int, input: Input, config: RunnableConfig + ) -> Tuple[int, Union[Output, Exception]]: + if return_exceptions: + try: + out: Union[Output, Exception] = await self.ainvoke( + input, config, **kwargs + ) + except Exception as e: + out = e + else: + out = await self.ainvoke(input, config, **kwargs) + + return (i, out) + + coros = map(ainvoke, range(len(inputs)), inputs, configs) + + for coro in asyncio.as_completed(coros): + yield await coro + def stream( self, input: Input, @@ -4149,6 +4277,113 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): **{**self.kwargs, **kwargs}, ) + @overload + def batch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[False] = False, + **kwargs: Any, + ) -> Iterator[Tuple[int, Output]]: + ... + + @overload + def batch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[True], + **kwargs: Any, + ) -> Iterator[Tuple[int, Union[Output, Exception]]]: + ... + + def batch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> Iterator[Tuple[int, Union[Output, Exception]]]: + if isinstance(config, list): + configs = cast( + List[RunnableConfig], + [self._merge_configs(conf) for conf in config], + ) + else: + configs = [self._merge_configs(config) for _ in range(len(inputs))] + # lol mypy + if return_exceptions: + yield from self.bound.batch_as_completed( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) + else: + yield from self.bound.batch_as_completed( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) + + @overload + def abatch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[False] = False, + **kwargs: Optional[Any], + ) -> AsyncIterator[Tuple[int, Output]]: + ... + + @overload + def abatch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: Literal[True], + **kwargs: Optional[Any], + ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: + ... + + async def abatch_as_completed( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Optional[Any], + ) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]: + if isinstance(config, list): + configs = cast( + List[RunnableConfig], + [self._merge_configs(conf) for conf in config], + ) + else: + configs = [self._merge_configs(config) for _ in range(len(inputs))] + if return_exceptions: + async for item in self.bound.abatch_as_completed( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ): + yield item + else: + async for item in self.bound.abatch_as_completed( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ): + yield item + def stream( self, input: Input, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 6fd34c81fcf..fc63abd9af7 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1428,6 +1428,30 @@ async def test_with_config(mocker: MockerFixture) -> None: spy.reset_mock() + assert sorted( + c + for c in fake.with_config(recursion_limit=5).batch_as_completed( + ["hello", "wooorld"], + [dict(tags=["a-tag"]), dict(metadata={"key": "value"})], + ) + ) == [(0, 5), (1, 7)] + + assert len(spy.call_args_list) == 2 + for i, call in enumerate( + sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1) + ): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + if i == 0: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {} + else: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == [] + assert call.args[1].get("metadata") == {"key": "value"} + + spy.reset_mock() + assert fake.with_config(metadata={"a": "b"}).batch( ["hello", "wooorld"], dict(tags=["a-tag"]) ) == [5, 7] @@ -1438,6 +1462,15 @@ async def test_with_config(mocker: MockerFixture) -> None: assert call.args[1].get("metadata") == {"a": "b"} spy.reset_mock() + assert sorted( + c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"])) + ) == [(0, 5), (1, 7)] + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + assert call.args[1].get("tags") == ["a-tag"] + spy.reset_mock() + handler = ConsoleCallbackHandler() assert ( await fake.with_config(metadata={"a": "b"}).ainvoke( @@ -1484,6 +1517,40 @@ async def test_with_config(mocker: MockerFixture) -> None: ), ), ] + spy.reset_mock() + + assert sorted( + [ + c + async for c in fake.with_config( + recursion_limit=5, tags=["c"] + ).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"})) + ] + ) == [ + (0, 5), + (1, 7), + ] + assert len(spy.call_args_list) == 2 + first_call = next(call for call in spy.call_args_list if call.args[0] == "hello") + assert first_call == mocker.call( + "hello", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + recursion_limit=5, + ), + ) + second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld") + assert second_call == mocker.call( + "wooorld", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + recursion_limit=5, + ), + ) async def test_default_method_implementations(mocker: MockerFixture) -> None: