mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-20 01:54:14 +00:00
fix(core): preserve ordering in RunnableRetry batch/abatch results (#32526)
Description: Fixes a bug in RunnableRetry where .batch / .abatch could return misordered outputs (e.g. inputs [0,1,2] yielding [1,1,2]) when some items succeeded on an earlier attempt and others were retried. Root cause: successful results were stored keyed by the index within the shrinking “pending” subset rather than the original input index, causing collisions and reordered/duplicated outputs after retries. Fix updates _batch and _abatch to: - Track remaining original indices explicitly. - Call underlying batch/abatch only on remaining inputs. - Map results back to original indices. - Preserve final ordering by reconstructing outputs in original positional order. Issue: Fixes #21326 Tests: - Added regression tests: test_retry_batch_preserves_order and test_async_retry_batch_preserves_order asserting correct ordering after a single controlled failure + retry. - Existing retry tests still pass. Dependencies: - None added or changed. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -238,31 +238,40 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede
|
||||
) -> list[Union[Output, Exception]]:
|
||||
results_map: dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: list[U]) -> list[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: list[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
# Get the results of the inputs that have not succeeded yet.
|
||||
# Retry for inputs that have not yet succeeded
|
||||
# Determine which original indices remain.
|
||||
remaining_indices = [
|
||||
i for i in range(len(inputs)) if i not in results_map
|
||||
]
|
||||
if not remaining_indices:
|
||||
break
|
||||
pending_inputs = [inputs[i] for i in remaining_indices]
|
||||
pending_configs = [config[i] for i in remaining_indices]
|
||||
pending_run_managers = [run_manager[i] for i in remaining_indices]
|
||||
# Invoke underlying batch only on remaining elements.
|
||||
result = super().batch(
|
||||
pending(inputs),
|
||||
pending_inputs,
|
||||
self._patch_config_list(
|
||||
pending(config), pending(run_manager), attempt.retry_state
|
||||
pending_configs, pending_run_managers, attempt.retry_state
|
||||
),
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
# Register the results of the inputs that have succeeded.
|
||||
# Register the results of the inputs that have succeeded, mapping
|
||||
# back to their original indices.
|
||||
first_exception = None
|
||||
for i, r in enumerate(result):
|
||||
for offset, r in enumerate(result):
|
||||
if isinstance(r, Exception):
|
||||
if not first_exception:
|
||||
first_exception = r
|
||||
continue
|
||||
results_map[i] = r
|
||||
orig_idx = remaining_indices[offset]
|
||||
results_map[orig_idx] = r
|
||||
# If any exception occurred, raise it, to retry the failed ones
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
@@ -305,31 +314,39 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede
|
||||
) -> list[Union[Output, Exception]]:
|
||||
results_map: dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: list[U]) -> list[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: list[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
# Get the results of the inputs that have not succeeded yet.
|
||||
# Retry for inputs that have not yet succeeded
|
||||
# Determine which original indices remain.
|
||||
remaining_indices = [
|
||||
i for i in range(len(inputs)) if i not in results_map
|
||||
]
|
||||
if not remaining_indices:
|
||||
break
|
||||
pending_inputs = [inputs[i] for i in remaining_indices]
|
||||
pending_configs = [config[i] for i in remaining_indices]
|
||||
pending_run_managers = [run_manager[i] for i in remaining_indices]
|
||||
result = await super().abatch(
|
||||
pending(inputs),
|
||||
pending_inputs,
|
||||
self._patch_config_list(
|
||||
pending(config), pending(run_manager), attempt.retry_state
|
||||
pending_configs, pending_run_managers, attempt.retry_state
|
||||
),
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
# Register the results of the inputs that have succeeded.
|
||||
# Register the results of the inputs that have succeeded, mapping
|
||||
# back to their original indices.
|
||||
first_exception = None
|
||||
for i, r in enumerate(result):
|
||||
for offset, r in enumerate(result):
|
||||
if isinstance(r, Exception):
|
||||
if not first_exception:
|
||||
first_exception = r
|
||||
continue
|
||||
results_map[i] = r
|
||||
orig_idx = remaining_indices[offset]
|
||||
results_map[orig_idx] = r
|
||||
# If any exception occurred, raise it, to retry the failed ones
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
|
@@ -3919,6 +3919,58 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
lambda_mock.reset_mock()
|
||||
|
||||
|
||||
def test_retry_batch_preserves_order() -> None:
|
||||
"""Regression test: batch with retry should preserve input order.
|
||||
|
||||
The previous implementation stored successful results in a map keyed by the
|
||||
index within the *pending* (filtered) list rather than the original input
|
||||
index, causing collisions after retries. This produced duplicated outputs
|
||||
and dropped earlier successes (e.g. [0,1,2] -> [1,1,2]).
|
||||
"""
|
||||
# Fail only the middle element on the first attempt to trigger the bug.
|
||||
first_fail: set[int] = {1}
|
||||
|
||||
def sometimes_fail(x: int) -> int: # pragma: no cover - trivial
|
||||
if x in first_fail:
|
||||
first_fail.remove(x)
|
||||
msg = "fail once"
|
||||
raise ValueError(msg)
|
||||
return x
|
||||
|
||||
runnable = RunnableLambda(sometimes_fail)
|
||||
|
||||
results = runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([0, 1, 2])
|
||||
|
||||
# Expect exact ordering preserved.
|
||||
assert results == [0, 1, 2]
|
||||
|
||||
|
||||
async def test_async_retry_batch_preserves_order() -> None:
|
||||
"""Async variant of order preservation regression test."""
|
||||
first_fail: set[int] = {1}
|
||||
|
||||
def sometimes_fail(x: int) -> int: # pragma: no cover - trivial
|
||||
if x in first_fail:
|
||||
first_fail.remove(x)
|
||||
msg = "fail once"
|
||||
raise ValueError(msg)
|
||||
return x
|
||||
|
||||
runnable = RunnableLambda(sometimes_fail)
|
||||
|
||||
results = await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
wait_exponential_jitter=False,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([0, 1, 2])
|
||||
|
||||
assert results == [0, 1, 2]
|
||||
|
||||
|
||||
async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
def _lambda(x: int) -> Union[int, Runnable]:
|
||||
if x == 1:
|
||||
|
Reference in New Issue
Block a user