diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index d3bb74547b0..06dfbe90497 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -238,31 +238,40 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede ) -> list[Union[Output, Exception]]: results_map: dict[int, Output] = {} - def pending(iterable: list[U]) -> list[U]: - return [item for idx, item in enumerate(iterable) if idx not in results_map] - not_set: list[Output] = [] result = not_set try: for attempt in self._sync_retrying(): with attempt: - # Get the results of the inputs that have not succeeded yet. + # Retry for inputs that have not yet succeeded + # Determine which original indices remain. + remaining_indices = [ + i for i in range(len(inputs)) if i not in results_map + ] + if not remaining_indices: + break + pending_inputs = [inputs[i] for i in remaining_indices] + pending_configs = [config[i] for i in remaining_indices] + pending_run_managers = [run_manager[i] for i in remaining_indices] + # Invoke underlying batch only on remaining elements. result = super().batch( - pending(inputs), + pending_inputs, self._patch_config_list( - pending(config), pending(run_manager), attempt.retry_state + pending_configs, pending_run_managers, attempt.retry_state ), return_exceptions=True, **kwargs, ) - # Register the results of the inputs that have succeeded. + # Register the results of the inputs that have succeeded, mapping + # back to their original indices. first_exception = None - for i, r in enumerate(result): + for offset, r in enumerate(result): if isinstance(r, Exception): if not first_exception: first_exception = r continue - results_map[i] = r + orig_idx = remaining_indices[offset] + results_map[orig_idx] = r # If any exception occurred, raise it, to retry the failed ones if first_exception: raise first_exception @@ -305,31 +314,39 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede ) -> list[Union[Output, Exception]]: results_map: dict[int, Output] = {} - def pending(iterable: list[U]) -> list[U]: - return [item for idx, item in enumerate(iterable) if idx not in results_map] - not_set: list[Output] = [] result = not_set try: async for attempt in self._async_retrying(): with attempt: - # Get the results of the inputs that have not succeeded yet. + # Retry for inputs that have not yet succeeded + # Determine which original indices remain. + remaining_indices = [ + i for i in range(len(inputs)) if i not in results_map + ] + if not remaining_indices: + break + pending_inputs = [inputs[i] for i in remaining_indices] + pending_configs = [config[i] for i in remaining_indices] + pending_run_managers = [run_manager[i] for i in remaining_indices] result = await super().abatch( - pending(inputs), + pending_inputs, self._patch_config_list( - pending(config), pending(run_manager), attempt.retry_state + pending_configs, pending_run_managers, attempt.retry_state ), return_exceptions=True, **kwargs, ) - # Register the results of the inputs that have succeeded. + # Register the results of the inputs that have succeeded, mapping + # back to their original indices. first_exception = None - for i, r in enumerate(result): + for offset, r in enumerate(result): if isinstance(r, Exception): if not first_exception: first_exception = r continue - results_map[i] = r + orig_idx = remaining_indices[offset] + results_map[orig_idx] = r # If any exception occurred, raise it, to retry the failed ones if first_exception: raise first_exception diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 36ef2f80834..ff637b09cc7 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3919,6 +3919,58 @@ def test_retrying(mocker: MockerFixture) -> None: lambda_mock.reset_mock() +def test_retry_batch_preserves_order() -> None: + """Regression test: batch with retry should preserve input order. + + The previous implementation stored successful results in a map keyed by the + index within the *pending* (filtered) list rather than the original input + index, causing collisions after retries. This produced duplicated outputs + and dropped earlier successes (e.g. [0,1,2] -> [1,1,2]). + """ + # Fail only the middle element on the first attempt to trigger the bug. + first_fail: set[int] = {1} + + def sometimes_fail(x: int) -> int: # pragma: no cover - trivial + if x in first_fail: + first_fail.remove(x) + msg = "fail once" + raise ValueError(msg) + return x + + runnable = RunnableLambda(sometimes_fail) + + results = runnable.with_retry( + stop_after_attempt=2, + wait_exponential_jitter=False, + retry_if_exception_type=(ValueError,), + ).batch([0, 1, 2]) + + # Expect exact ordering preserved. + assert results == [0, 1, 2] + + +async def test_async_retry_batch_preserves_order() -> None: + """Async variant of order preservation regression test.""" + first_fail: set[int] = {1} + + def sometimes_fail(x: int) -> int: # pragma: no cover - trivial + if x in first_fail: + first_fail.remove(x) + msg = "fail once" + raise ValueError(msg) + return x + + runnable = RunnableLambda(sometimes_fail) + + results = await runnable.with_retry( + stop_after_attempt=2, + wait_exponential_jitter=False, + retry_if_exception_type=(ValueError,), + ).abatch([0, 1, 2]) + + assert results == [0, 1, 2] + + async def test_async_retrying(mocker: MockerFixture) -> None: def _lambda(x: int) -> Union[int, Runnable]: if x == 1: