From 8e053ac9d22e325f5edf7abdf8fcdfa7eedb9d9b Mon Sep 17 00:00:00 2001 From: ccurme Date: Thu, 10 Apr 2025 19:18:36 -0400 Subject: [PATCH] core[patch]: support customization of backoff parameters in `with_retries` (#30773) Co-authored-by: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> --- libs/core/langchain_core/runnables/base.py | 15 ++++------ libs/core/langchain_core/runnables/retry.py | 30 ++++++++++++++++--- .../unit_tests/runnables/test_runnable.py | 1 + 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 1c48e9aac4d..68c7e14f1b4 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -94,6 +94,7 @@ if TYPE_CHECKING: from langchain_core.runnables.fallbacks import ( RunnableWithFallbacks as RunnableWithFallbacksT, ) + from langchain_core.runnables.retry import ExponentialJitterParams from langchain_core.runnables.schema import StreamEvent from langchain_core.tools import BaseTool from langchain_core.tracers.log_stream import ( @@ -1742,6 +1743,7 @@ class Runnable(Generic[Input, Output], ABC): *, retry_if_exception_type: tuple[type[BaseException], ...] = (Exception,), wait_exponential_jitter: bool = True, + exponential_jitter_params: Optional[ExponentialJitterParams] = None, stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: """Create a new Runnable that retries the original Runnable on exceptions. @@ -1753,6 +1755,9 @@ class Runnable(Generic[Input, Output], ABC): time between retries. Defaults to True. stop_after_attempt: The maximum number of attempts to make before giving up. Defaults to 3. + exponential_jitter_params: Parameters for + ``tenacity.wait_exponential_jitter``. Namely: ``initial``, ``max``, + ``exp_base``, and ``jitter`` (all float values). Returns: A new Runnable that retries the original Runnable on exceptions. @@ -1786,15 +1791,6 @@ class Runnable(Generic[Input, Output], ABC): assert (count == 2) - - Args: - retry_if_exception_type: A tuple of exception types to retry on - wait_exponential_jitter: Whether to add jitter to the wait time - between retries - stop_after_attempt: The maximum number of attempts to make before giving up - - Returns: - A new Runnable that retries the original Runnable on exceptions. """ from langchain_core.runnables.retry import RunnableRetry @@ -1805,6 +1801,7 @@ class Runnable(Generic[Input, Output], ABC): retry_exception_types=retry_if_exception_type, wait_exponential_jitter=wait_exponential_jitter, max_attempt_number=stop_after_attempt, + exponential_jitter_params=exponential_jitter_params, ) def map(self) -> Runnable[list[Input], list[Output]]: diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index cbbb9da0776..4ef52f8b736 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -18,7 +18,7 @@ from tenacity import ( stop_after_attempt, wait_exponential_jitter, ) -from typing_extensions import override +from typing_extensions import TypedDict, override from langchain_core.runnables.base import Input, Output, RunnableBindingBase from langchain_core.runnables.config import RunnableConfig, patch_config @@ -33,6 +33,19 @@ if TYPE_CHECKING: U = TypeVar("U") +class ExponentialJitterParams(TypedDict, total=False): + """Parameters for ``tenacity.wait_exponential_jitter``.""" + + initial: float + """Initial wait.""" + max: float + """Maximum wait.""" + exp_base: float + """Base for exponential backoff.""" + jitter: float + """Random additional wait sampled from random.uniform(0, jitter).""" + + class RunnableRetry(RunnableBindingBase[Input, Output]): """Retry a Runnable if it fails. @@ -62,6 +75,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): retry_if_exception_type=(ValueError,), # Retry only on ValueError wait_exponential_jitter=True, # Add jitter to the exponential backoff stop_after_attempt=2, # Try twice + exponential_jitter_params={"initial": 2}, # if desired, customize backoff ) # The method invocation above is equivalent to the longer form below: @@ -70,7 +84,8 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): bound=runnable, retry_exception_types=(ValueError,), max_attempt_number=2, - wait_exponential_jitter=True + wait_exponential_jitter=True, + exponential_jitter_params={"initial": 2}, ) This logic can be used to retry any Runnable, including a chain of Runnables, @@ -94,7 +109,7 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): # Bad chain = template | model retryable_chain = chain.with_retry() - """ + """ # noqa: E501 retry_exception_types: tuple[type[BaseException], ...] = (Exception,) """The exception types to retry on. By default all exceptions are retried. @@ -109,6 +124,11 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): wait_exponential_jitter: bool = True """Whether to add jitter to the exponential backoff.""" + exponential_jitter_params: Optional[ExponentialJitterParams] = None + """Parameters for ``tenacity.wait_exponential_jitter``. Namely: ``initial``, + ``max``, ``exp_base``, and ``jitter`` (all float values). + """ + max_attempt_number: int = 3 """The maximum number of attempts to retry the Runnable.""" @@ -120,7 +140,9 @@ class RunnableRetry(RunnableBindingBase[Input, Output]): kwargs["stop"] = stop_after_attempt(self.max_attempt_number) if self.wait_exponential_jitter: - kwargs["wait"] = wait_exponential_jitter() + kwargs["wait"] = wait_exponential_jitter( + **(self.exponential_jitter_params or {}) + ) if self.retry_exception_types: kwargs["retry"] = retry_if_exception_type(self.retry_exception_types) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 1d4c3d95902..9a63d432678 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3883,6 +3883,7 @@ def test_retrying(mocker: MockerFixture) -> None: runnable.with_retry( stop_after_attempt=2, retry_if_exception_type=(ValueError,), + exponential_jitter_params={"initial": 0.1}, ).invoke(1) assert _lambda_mock.call_count == 2 # retried