core[patch]: support customization of backoff parameters in with_retries (#30773)

Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
This commit is contained in:
ccurme 2025-04-10 19:18:36 -04:00 committed by GitHub
parent e981a9810d
commit 8e053ac9d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 33 additions and 13 deletions

View File

@ -94,6 +94,7 @@ if TYPE_CHECKING:
from langchain_core.runnables.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain_core.runnables.retry import ExponentialJitterParams
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers.log_stream import (
@ -1742,6 +1743,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,),
wait_exponential_jitter: bool = True,
exponential_jitter_params: Optional[ExponentialJitterParams] = None,
stop_after_attempt: int = 3,
) -> Runnable[Input, Output]:
"""Create a new Runnable that retries the original Runnable on exceptions.
@ -1753,6 +1755,9 @@ class Runnable(Generic[Input, Output], ABC):
time between retries. Defaults to True.
stop_after_attempt: The maximum number of attempts to make before
giving up. Defaults to 3.
exponential_jitter_params: Parameters for
``tenacity.wait_exponential_jitter``. Namely: ``initial``, ``max``,
``exp_base``, and ``jitter`` (all float values).
Returns:
A new Runnable that retries the original Runnable on exceptions.
@ -1786,15 +1791,6 @@ class Runnable(Generic[Input, Output], ABC):
assert (count == 2)
Args:
retry_if_exception_type: A tuple of exception types to retry on
wait_exponential_jitter: Whether to add jitter to the wait time
between retries
stop_after_attempt: The maximum number of attempts to make before giving up
Returns:
A new Runnable that retries the original Runnable on exceptions.
"""
from langchain_core.runnables.retry import RunnableRetry
@ -1805,6 +1801,7 @@ class Runnable(Generic[Input, Output], ABC):
retry_exception_types=retry_if_exception_type,
wait_exponential_jitter=wait_exponential_jitter,
max_attempt_number=stop_after_attempt,
exponential_jitter_params=exponential_jitter_params,
)
def map(self) -> Runnable[list[Input], list[Output]]:

View File

@ -18,7 +18,7 @@ from tenacity import (
stop_after_attempt,
wait_exponential_jitter,
)
from typing_extensions import override
from typing_extensions import TypedDict, override
from langchain_core.runnables.base import Input, Output, RunnableBindingBase
from langchain_core.runnables.config import RunnableConfig, patch_config
@ -33,6 +33,19 @@ if TYPE_CHECKING:
U = TypeVar("U")
class ExponentialJitterParams(TypedDict, total=False):
"""Parameters for ``tenacity.wait_exponential_jitter``."""
initial: float
"""Initial wait."""
max: float
"""Maximum wait."""
exp_base: float
"""Base for exponential backoff."""
jitter: float
"""Random additional wait sampled from random.uniform(0, jitter)."""
class RunnableRetry(RunnableBindingBase[Input, Output]):
"""Retry a Runnable if it fails.
@ -62,6 +75,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
retry_if_exception_type=(ValueError,), # Retry only on ValueError
wait_exponential_jitter=True, # Add jitter to the exponential backoff
stop_after_attempt=2, # Try twice
exponential_jitter_params={"initial": 2}, # if desired, customize backoff
)
# The method invocation above is equivalent to the longer form below:
@ -70,7 +84,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
bound=runnable,
retry_exception_types=(ValueError,),
max_attempt_number=2,
wait_exponential_jitter=True
wait_exponential_jitter=True,
exponential_jitter_params={"initial": 2},
)
This logic can be used to retry any Runnable, including a chain of Runnables,
@ -94,7 +109,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
# Bad
chain = template | model
retryable_chain = chain.with_retry()
"""
""" # noqa: E501
retry_exception_types: tuple[type[BaseException], ...] = (Exception,)
"""The exception types to retry on. By default all exceptions are retried.
@ -109,6 +124,11 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
wait_exponential_jitter: bool = True
"""Whether to add jitter to the exponential backoff."""
exponential_jitter_params: Optional[ExponentialJitterParams] = None
"""Parameters for ``tenacity.wait_exponential_jitter``. Namely: ``initial``,
``max``, ``exp_base``, and ``jitter`` (all float values).
"""
max_attempt_number: int = 3
"""The maximum number of attempts to retry the Runnable."""
@ -120,7 +140,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
kwargs["stop"] = stop_after_attempt(self.max_attempt_number)
if self.wait_exponential_jitter:
kwargs["wait"] = wait_exponential_jitter()
kwargs["wait"] = wait_exponential_jitter(
**(self.exponential_jitter_params or {})
)
if self.retry_exception_types:
kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)

View File

@ -3883,6 +3883,7 @@ def test_retrying(mocker: MockerFixture) -> None:
runnable.with_retry(
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
exponential_jitter_params={"initial": 0.1},
).invoke(1)
assert _lambda_mock.call_count == 2 # retried