mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 12:01:54 +00:00
fix(core): preserve ordering in RunnableRetry batch/abatch results (#32526)
Description: Fixes a bug in RunnableRetry where .batch / .abatch could return misordered outputs (e.g. inputs [0,1,2] yielding [1,1,2]) when some items succeeded on an earlier attempt and others were retried. Root cause: successful results were stored keyed by the index within the shrinking “pending” subset rather than the original input index, causing collisions and reordered/duplicated outputs after retries. Fix updates _batch and _abatch to: - Track remaining original indices explicitly. - Call underlying batch/abatch only on remaining inputs. - Map results back to original indices. - Preserve final ordering by reconstructing outputs in original positional order. Issue: Fixes #21326 Tests: - Added regression tests: test_retry_batch_preserves_order and test_async_retry_batch_preserves_order asserting correct ordering after a single controlled failure + retry. - Existing retry tests still pass. Dependencies: - None added or changed. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -238,31 +238,40 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede
|
||||
) -> list[Union[Output, Exception]]:
|
||||
results_map: dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: list[U]) -> list[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: list[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
# Get the results of the inputs that have not succeeded yet.
|
||||
# Retry for inputs that have not yet succeeded
|
||||
# Determine which original indices remain.
|
||||
remaining_indices = [
|
||||
i for i in range(len(inputs)) if i not in results_map
|
||||
]
|
||||
if not remaining_indices:
|
||||
break
|
||||
pending_inputs = [inputs[i] for i in remaining_indices]
|
||||
pending_configs = [config[i] for i in remaining_indices]
|
||||
pending_run_managers = [run_manager[i] for i in remaining_indices]
|
||||
# Invoke underlying batch only on remaining elements.
|
||||
result = super().batch(
|
||||
pending(inputs),
|
||||
pending_inputs,
|
||||
self._patch_config_list(
|
||||
pending(config), pending(run_manager), attempt.retry_state
|
||||
pending_configs, pending_run_managers, attempt.retry_state
|
||||
),
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
# Register the results of the inputs that have succeeded.
|
||||
# Register the results of the inputs that have succeeded, mapping
|
||||
# back to their original indices.
|
||||
first_exception = None
|
||||
for i, r in enumerate(result):
|
||||
for offset, r in enumerate(result):
|
||||
if isinstance(r, Exception):
|
||||
if not first_exception:
|
||||
first_exception = r
|
||||
continue
|
||||
results_map[i] = r
|
||||
orig_idx = remaining_indices[offset]
|
||||
results_map[orig_idx] = r
|
||||
# If any exception occurred, raise it, to retry the failed ones
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
@@ -305,31 +314,39 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-rede
|
||||
) -> list[Union[Output, Exception]]:
|
||||
results_map: dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: list[U]) -> list[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
not_set: list[Output] = []
|
||||
result = not_set
|
||||
try:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
# Get the results of the inputs that have not succeeded yet.
|
||||
# Retry for inputs that have not yet succeeded
|
||||
# Determine which original indices remain.
|
||||
remaining_indices = [
|
||||
i for i in range(len(inputs)) if i not in results_map
|
||||
]
|
||||
if not remaining_indices:
|
||||
break
|
||||
pending_inputs = [inputs[i] for i in remaining_indices]
|
||||
pending_configs = [config[i] for i in remaining_indices]
|
||||
pending_run_managers = [run_manager[i] for i in remaining_indices]
|
||||
result = await super().abatch(
|
||||
pending(inputs),
|
||||
pending_inputs,
|
||||
self._patch_config_list(
|
||||
pending(config), pending(run_manager), attempt.retry_state
|
||||
pending_configs, pending_run_managers, attempt.retry_state
|
||||
),
|
||||
return_exceptions=True,
|
||||
**kwargs,
|
||||
)
|
||||
# Register the results of the inputs that have succeeded.
|
||||
# Register the results of the inputs that have succeeded, mapping
|
||||
# back to their original indices.
|
||||
first_exception = None
|
||||
for i, r in enumerate(result):
|
||||
for offset, r in enumerate(result):
|
||||
if isinstance(r, Exception):
|
||||
if not first_exception:
|
||||
first_exception = r
|
||||
continue
|
||||
results_map[i] = r
|
||||
orig_idx = remaining_indices[offset]
|
||||
results_map[orig_idx] = r
|
||||
# If any exception occurred, raise it, to retry the failed ones
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
|
Reference in New Issue
Block a user