fix(core): preserve ordering in RunnableRetry batch/abatch results (#32526)

Description: Fixes a bug in RunnableRetry where .batch / .abatch could
return misordered outputs (e.g. inputs [0,1,2] yielding [1,1,2]) when
some items succeeded on an earlier attempt and others were retried. Root
cause: successful results were stored keyed by the index within the
shrinking “pending” subset rather than the original input index, causing
collisions and reordered/duplicated outputs after retries. The fix updates
_batch and _abatch to:

- Track remaining original indices explicitly.
- Call underlying batch/abatch only on remaining inputs.
- Map results back to original indices.
- Preserve final ordering by reconstructing outputs in original
positional order.

Issue: Fixes #21326

Tests:

- Added regression tests: test_retry_batch_preserves_order and
test_async_retry_batch_preserves_order asserting correct ordering after
a single controlled failure + retry.
- Existing retry tests still pass.

Dependencies:

- None added or changed.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Vincent Min
2025-09-11 22:18:25 +02:00
committed by GitHub
parent b1f08467cd
commit ff9f17bc66
2 changed files with 87 additions and 18 deletions

View File

@@ -238,31 +238,40 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}
def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: list[Output] = []
result = not_set
try:
for attempt in self._sync_retrying():
with attempt:
# Get the results of the inputs that have not succeeded yet.
# Retry for inputs that have not yet succeeded
# Determine which original indices remain.
remaining_indices = [
i for i in range(len(inputs)) if i not in results_map
]
if not remaining_indices:
break
pending_inputs = [inputs[i] for i in remaining_indices]
pending_configs = [config[i] for i in remaining_indices]
pending_run_managers = [run_manager[i] for i in remaining_indices]
# Invoke underlying batch only on remaining elements.
result = super().batch(
pending(inputs),
pending_inputs,
self._patch_config_list(
pending(config), pending(run_manager), attempt.retry_state
pending_configs, pending_run_managers, attempt.retry_state
),
return_exceptions=True,
**kwargs,
)
# Register the results of the inputs that have succeeded.
# Register the results of the inputs that have succeeded, mapping
# back to their original indices.
first_exception = None
for i, r in enumerate(result):
for offset, r in enumerate(result):
if isinstance(r, Exception):
if not first_exception:
first_exception = r
continue
results_map[i] = r
orig_idx = remaining_indices[offset]
results_map[orig_idx] = r
# If any exception occurred, raise it, to retry the failed ones
if first_exception:
raise first_exception
@@ -305,31 +314,39 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede
) -> list[Union[Output, Exception]]:
results_map: dict[int, Output] = {}
def pending(iterable: list[U]) -> list[U]:
return [item for idx, item in enumerate(iterable) if idx not in results_map]
not_set: list[Output] = []
result = not_set
try:
async for attempt in self._async_retrying():
with attempt:
# Get the results of the inputs that have not succeeded yet.
# Retry for inputs that have not yet succeeded
# Determine which original indices remain.
remaining_indices = [
i for i in range(len(inputs)) if i not in results_map
]
if not remaining_indices:
break
pending_inputs = [inputs[i] for i in remaining_indices]
pending_configs = [config[i] for i in remaining_indices]
pending_run_managers = [run_manager[i] for i in remaining_indices]
result = await super().abatch(
pending(inputs),
pending_inputs,
self._patch_config_list(
pending(config), pending(run_manager), attempt.retry_state
pending_configs, pending_run_managers, attempt.retry_state
),
return_exceptions=True,
**kwargs,
)
# Register the results of the inputs that have succeeded.
# Register the results of the inputs that have succeeded, mapping
# back to their original indices.
first_exception = None
for i, r in enumerate(result):
for offset, r in enumerate(result):
if isinstance(r, Exception):
if not first_exception:
first_exception = r
continue
results_map[i] = r
orig_idx = remaining_indices[offset]
results_map[orig_idx] = r
# If any exception occurred, raise it, to retry the failed ones
if first_exception:
raise first_exception

View File

@@ -3919,6 +3919,58 @@ def test_retrying(mocker: MockerFixture) -> None:
lambda_mock.reset_mock()
def test_retry_batch_preserves_order() -> None:
    """Check that ``batch`` on a retrying runnable keeps outputs in input order.

    Regression guard: results used to be recorded under their position in the
    filtered list of still-pending inputs instead of their position in the
    original input list. After a retry those keys collided, so earlier
    successes were overwritten and outputs were duplicated
    (e.g. [0,1,2] -> [1,1,2]).
    """
    # Inputs that must raise exactly once; only the middle element is listed
    # so that the first attempt is a partial success followed by one retry.
    fail_once_inputs: set[int] = {1}

    def sometimes_fail(x: int) -> int:  # pragma: no cover - trivial
        if x not in fail_once_inputs:
            return x
        # First time this input is seen: forget it, then fail.
        fail_once_inputs.remove(x)
        msg = "fail once"
        raise ValueError(msg)

    retrying = RunnableLambda(sometimes_fail).with_retry(
        stop_after_attempt=2,
        wait_exponential_jitter=False,
        retry_if_exception_type=(ValueError,),
    )
    outputs = retrying.batch([0, 1, 2])

    # Each output must sit at the position of the input that produced it.
    assert outputs == [0, 1, 2]
async def test_async_retry_batch_preserves_order() -> None:
    """Check that ``abatch`` on a retrying runnable keeps outputs in input order.

    Async counterpart of the synchronous order-preservation regression test:
    the middle input fails once, is retried, and the final outputs must still
    line up with the original inputs.
    """
    # Inputs that must raise exactly once before succeeding.
    fail_once_inputs: set[int] = {1}

    def sometimes_fail(x: int) -> int:  # pragma: no cover - trivial
        if x not in fail_once_inputs:
            return x
        # First time this input is seen: forget it, then fail.
        fail_once_inputs.remove(x)
        msg = "fail once"
        raise ValueError(msg)

    retrying = RunnableLambda(sometimes_fail).with_retry(
        stop_after_attempt=2,
        wait_exponential_jitter=False,
        retry_if_exception_type=(ValueError,),
    )
    outputs = await retrying.abatch([0, 1, 2])

    # Each output must sit at the position of the input that produced it.
    assert outputs == [0, 1, 2]
async def test_async_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
if x == 1: