Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-08 10:09:46 +00:00

Compare commits: mdrxy/post...sr/retry (1 commit, SHA 44dab2bd4a)

libs/langchain/uv.lock (generated): 4 changes
@@ -1,5 +1,5 @@
 version = 1
-revision = 3
+revision = 2
 requires-python = ">=3.10.0, <4.0.0"
 resolution-markers = [
     "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
@@ -2478,7 +2478,7 @@ typing = [
 
 [[package]]
 name = "langchain-core"
-version = "1.0.3"
+version = "1.0.4"
 source = { editable = "../core" }
 dependencies = [
     { name = "jsonpatch" },
libs/langchain_v1/langchain/agents/middleware/__init__.py

@@ -11,6 +11,7 @@ from .human_in_the_loop import (
 )
 from .model_call_limit import ModelCallLimitMiddleware
 from .model_fallback import ModelFallbackMiddleware
+from .model_retry import ModelRetryMiddleware
 from .pii import PIIDetectionError, PIIMiddleware
 from .shell_tool import (
     CodexSandboxExecutionPolicy,
@@ -57,6 +58,7 @@ __all__ = [
     "ModelFallbackMiddleware",
     "ModelRequest",
     "ModelResponse",
+    "ModelRetryMiddleware",
     "PIIDetectionError",
     "PIIMiddleware",
     "RedactionRule",
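With this re-export, `ModelRetryMiddleware` becomes importable from `langchain.agents.middleware` alongside the other middleware classes. A minimal usage sketch, mirroring the docstring example in the new module; `model` and `search_tool` are placeholders, not defined anywhere in this diff:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelRetryMiddleware

# Retry transient model failures up to 4 times with exponential backoff.
# `model` and `search_tool` are assumed to be provided by the application.
agent = create_agent(
    model,
    tools=[search_tool],
    middleware=[ModelRetryMiddleware(max_retries=4, backoff_factor=1.5)],
)
```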
libs/langchain_v1/langchain/agents/middleware/model_retry.py (new file, 271 lines)

@@ -0,0 +1,271 @@
"""Model retry middleware for agents."""

from __future__ import annotations

import asyncio
import random
import time
from typing import TYPE_CHECKING, Literal

from langchain.agents.middleware.types import AgentMiddleware, ModelCallResult

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from langchain.agents.middleware.types import ModelRequest, ModelResponse


class ModelRetryMiddleware(AgentMiddleware):
    """Middleware that automatically retries failed model calls with configurable backoff.

    Wraps each model call in a retry loop: retryable exceptions trigger another attempt
    after an exponential-backoff delay, while non-retryable exceptions are re-raised
    immediately. Retries can be limited to specific exception types or to a custom
    predicate.

    Examples:
        Basic usage with default settings (2 retries, exponential backoff):

        ```python
        from langchain.agents import create_agent
        from langchain.agents.middleware import ModelRetryMiddleware

        agent = create_agent(model, tools=[search_tool], middleware=[ModelRetryMiddleware()])
        ```

        Retry specific exceptions only:

        ```python
        from requests.exceptions import RequestException, Timeout

        retry = ModelRetryMiddleware(
            max_retries=4,
            retry_on=(RequestException, Timeout),
            backoff_factor=1.5,
        )
        ```

        Custom exception filtering:

        ```python
        from requests.exceptions import HTTPError


        def should_retry(exc: Exception) -> bool:
            # Only retry on 5xx errors (HTTPError carries the response object)
            if isinstance(exc, HTTPError) and exc.response is not None:
                return 500 <= exc.response.status_code < 600
            return False


        retry = ModelRetryMiddleware(
            max_retries=3,
            retry_on=should_retry,
        )
        ```

        Constant backoff (no exponential growth):

        ```python
        retry = ModelRetryMiddleware(
            max_retries=5,
            backoff_factor=0.0,  # No exponential growth
            initial_delay=2.0,  # Always wait 2 seconds
        )
        ```

        Raise exception on failure:

        ```python
        retry = ModelRetryMiddleware(
            max_retries=2,
            on_failure="raise",  # Re-raise the exception once retries are exhausted
        )
        ```
    """

    def __init__(
        self,
        *,
        max_retries: int = 2,
        retry_on: tuple[type[Exception], ...] | Callable[[Exception], bool] = (Exception,),
        on_failure: Literal["raise"] = "raise",
        backoff_factor: float = 2.0,
        initial_delay: float = 1.0,
        max_delay: float = 60.0,
        jitter: bool = True,
    ) -> None:
        """Initialize `ModelRetryMiddleware`.

        Args:
            max_retries: Maximum number of retry attempts after the initial call.
                Default is `2` retries (`3` total attempts). Must be `>= 0`.
            retry_on: Either a tuple of exception types to retry on, or a callable
                that takes an exception and returns `True` if it should be retried.

                Default is to retry on all exceptions.
            on_failure: Behavior when all retries are exhausted.

                Currently only `'raise'` is supported (re-raises the exception).
            backoff_factor: Multiplier for exponential backoff.

                Each retry waits `initial_delay * (backoff_factor ** retry_number)`
                seconds.

                Set to `0.0` for constant delay.
            initial_delay: Initial delay in seconds before first retry.
            max_delay: Maximum delay in seconds between retries.

                Caps exponential backoff growth.
            jitter: Whether to add random jitter (`±25%`) to delay to avoid thundering herd.

        Raises:
            ValueError: If `max_retries < 0` or delays are negative.
        """
        super().__init__()

        # Validate parameters
        if max_retries < 0:
            msg = "max_retries must be >= 0"
            raise ValueError(msg)
        if initial_delay < 0:
            msg = "initial_delay must be >= 0"
            raise ValueError(msg)
        if max_delay < 0:
            msg = "max_delay must be >= 0"
            raise ValueError(msg)
        if backoff_factor < 0:
            msg = "backoff_factor must be >= 0"
            raise ValueError(msg)
        if on_failure != "raise":
            msg = "Currently only on_failure='raise' is supported"
            raise ValueError(msg)

        self.max_retries = max_retries
        self.tools = []  # No additional tools registered by this middleware
        self.retry_on = retry_on
        self.on_failure = on_failure
        self.backoff_factor = backoff_factor
        self.initial_delay = initial_delay
        self.max_delay = max_delay
        self.jitter = jitter

    def _should_retry_exception(self, exc: Exception) -> bool:
        """Check if the exception should trigger a retry.

        Args:
            exc: The exception that occurred.

        Returns:
            `True` if the exception should be retried, `False` otherwise.
        """
        if callable(self.retry_on):
            return self.retry_on(exc)
        return isinstance(exc, self.retry_on)

    def _calculate_delay(self, retry_number: int) -> float:
        """Calculate delay for the given retry attempt.

        Args:
            retry_number: The retry attempt number (0-indexed).

        Returns:
            Delay in seconds before next retry.
        """
        if self.backoff_factor == 0.0:
            delay = self.initial_delay
        else:
            delay = self.initial_delay * (self.backoff_factor**retry_number)

        # Cap at max_delay
        delay = min(delay, self.max_delay)

        if self.jitter and delay > 0:
            jitter_amount = delay * 0.25
            delay = delay + random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
            # Ensure delay is not negative after jitter
            delay = max(0, delay)

        return delay

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Intercept model execution and retry on failure.

        Args:
            request: Model request with model, messages, state, and runtime.
            handler: Callable to execute the model (can be called multiple times).

        Returns:
            `ModelResponse` or `AIMessage` (the final result).

        Raises:
            Exception: If all retries are exhausted and `on_failure='raise'`.
        """
        # Initial attempt + retries
        for attempt in range(self.max_retries + 1):
            try:
                return handler(request)
            except Exception as exc:
                # Check if we should retry this exception
                if not self._should_retry_exception(exc):
                    # Exception is not retryable, raise immediately
                    raise

                # Check if we have more retries left
                if attempt < self.max_retries:
                    # Calculate and apply backoff delay
                    delay = self._calculate_delay(attempt)
                    if delay > 0:
                        time.sleep(delay)
                    # Continue to next retry
                else:
                    # No more retries, raise the exception
                    raise

        # Unreachable: loop always returns via handler success or raises
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Intercept and control async model execution with retry logic.

        Args:
            request: Model request with model, messages, state, and runtime.
            handler: Async callable that executes the model and returns `ModelResponse`.

        Returns:
            `ModelResponse` or `AIMessage` (the final result).

        Raises:
            Exception: If all retries are exhausted and `on_failure='raise'`.
        """
        # Initial attempt + retries
        for attempt in range(self.max_retries + 1):
            try:
                return await handler(request)
            except Exception as exc:
                # Check if we should retry this exception
                if not self._should_retry_exception(exc):
                    # Exception is not retryable, raise immediately
                    raise

                # Check if we have more retries left
                if attempt < self.max_retries:
                    # Calculate and apply backoff delay
                    delay = self._calculate_delay(attempt)
                    if delay > 0:
                        await asyncio.sleep(delay)
                    # Continue to next retry
                else:
                    # No more retries, raise the exception
                    raise

        # Unreachable: loop always returns via handler success or raises
        msg = "Unexpected: retry loop completed without returning"
        raise RuntimeError(msg)
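The retry schedule implemented by `_calculate_delay` above can be reproduced standalone. The following is an illustrative sketch of the same formula (`initial_delay * backoff_factor ** retry_number`, capped at `max_delay`, with optional ±25% jitter); the helper name `backoff_delays` is hypothetical and not part of the commit:

```python
import random


def backoff_delays(
    retries: int,
    initial_delay: float = 1.0,
    backoff_factor: float = 2.0,
    max_delay: float = 60.0,
    jitter: bool = False,
) -> list[float]:
    """Mirror the middleware's delay formula, for illustration only."""
    delays = []
    for retry_number in range(retries):
        if backoff_factor == 0.0:
            delay = initial_delay  # constant backoff
        else:
            delay = initial_delay * (backoff_factor**retry_number)
        delay = min(delay, max_delay)  # cap exponential growth
        if jitter and delay > 0:
            delay += random.uniform(-0.25 * delay, 0.25 * delay)
        delays.append(max(0.0, delay))
    return delays


print(backoff_delays(2))  # defaults: [1.0, 2.0]
print(backoff_delays(5, backoff_factor=10.0, max_delay=2.0))  # [1.0, 2.0, 2.0, 2.0, 2.0]
```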
New test file (535 lines)

@@ -0,0 +1,535 @@
"""Tests for ModelRetryMiddleware functionality."""

import asyncio
import time
from typing import cast

import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from langchain.agents.factory import create_agent
from langchain.agents.middleware.model_retry import ModelRetryMiddleware
from langchain.agents.middleware.types import ModelRequest, ModelResponse
from langgraph.runtime import Runtime


def _fake_runtime() -> Runtime:
    return cast(Runtime, object())


def _make_request() -> ModelRequest:
    """Create a minimal ModelRequest for testing."""
    model = GenericFakeChatModel(messages=iter([AIMessage(content="success")]))
    return ModelRequest(
        model=model,
        system_prompt=None,
        messages=[],
        tool_choice=None,
        tools=[],
        response_format=None,
        state=cast("AgentState", {}),  # type: ignore[name-defined]
        runtime=_fake_runtime(),
        model_settings={},
    )


class TemporaryFailureModel:
    """Model that fails a certain number of times before succeeding."""

    def __init__(self, fail_count: int):
        """Initialize with the number of times to fail.

        Args:
            fail_count: Number of times to fail before succeeding.
        """
        self.fail_count = fail_count
        self.attempt = 0

    def __call__(self) -> AIMessage:
        """Execute the model call.

        Returns:
            Success message once the first `fail_count` calls have failed.

        Raises:
            ValueError: On each of the first `fail_count` calls.
        """
        self.attempt += 1
        if self.attempt <= self.fail_count:
            msg = f"Temporary failure {self.attempt}"
            raise ValueError(msg)
        return AIMessage(content=f"Success after {self.attempt} attempts")

def test_model_retry_initialization_defaults() -> None:
    """Test ModelRetryMiddleware initialization with default values."""
    retry = ModelRetryMiddleware()

    assert retry.max_retries == 2
    assert retry.tools == []
    assert retry.on_failure == "raise"
    assert retry.backoff_factor == 2.0
    assert retry.initial_delay == 1.0
    assert retry.max_delay == 60.0
    assert retry.jitter is True


def test_model_retry_initialization_custom() -> None:
    """Test ModelRetryMiddleware initialization with custom values."""
    retry = ModelRetryMiddleware(
        max_retries=5,
        retry_on=(ValueError, RuntimeError),
        on_failure="raise",
        backoff_factor=1.5,
        initial_delay=0.5,
        max_delay=30.0,
        jitter=False,
    )

    assert retry.max_retries == 5
    assert retry.tools == []
    assert retry.retry_on == (ValueError, RuntimeError)
    assert retry.on_failure == "raise"
    assert retry.backoff_factor == 1.5
    assert retry.initial_delay == 0.5
    assert retry.max_delay == 30.0
    assert retry.jitter is False


def test_model_retry_invalid_max_retries() -> None:
    """Test ModelRetryMiddleware raises error for invalid max_retries."""
    with pytest.raises(ValueError, match="max_retries must be >= 0"):
        ModelRetryMiddleware(max_retries=-1)


def test_model_retry_invalid_initial_delay() -> None:
    """Test ModelRetryMiddleware raises error for invalid initial_delay."""
    with pytest.raises(ValueError, match="initial_delay must be >= 0"):
        ModelRetryMiddleware(initial_delay=-1.0)


def test_model_retry_invalid_max_delay() -> None:
    """Test ModelRetryMiddleware raises error for invalid max_delay."""
    with pytest.raises(ValueError, match="max_delay must be >= 0"):
        ModelRetryMiddleware(max_delay=-1.0)


def test_model_retry_invalid_backoff_factor() -> None:
    """Test ModelRetryMiddleware raises error for invalid backoff_factor."""
    with pytest.raises(ValueError, match="backoff_factor must be >= 0"):
        ModelRetryMiddleware(backoff_factor=-1.0)

def test_model_retry_working_model_no_retry_needed() -> None:
    """Test ModelRetryMiddleware with a working model (no retry needed)."""
    retry = ModelRetryMiddleware(max_retries=2, initial_delay=0.01, jitter=False)
    request = _make_request()

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = AIMessage(content="success")
        return ModelResponse(result=[result])

    response = retry.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "success"


def test_model_retry_failing_model_raises() -> None:
    """Test ModelRetryMiddleware with failing model raises exception."""

    class FailingModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

    retry = ModelRetryMiddleware(
        max_retries=2,
        initial_delay=0.01,
        jitter=False,
        on_failure="raise",
    )
    request = _make_request()
    request.model = FailingModel(messages=iter([]))

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    # Should raise the ValueError from the model after all retries
    with pytest.raises(ValueError, match="Model failed"):
        retry.wrap_model_call(request, mock_handler)


def test_model_retry_succeeds_after_retries() -> None:
    """Test ModelRetryMiddleware succeeds after temporary failures."""
    temp_fail = TemporaryFailureModel(fail_count=2)

    retry = ModelRetryMiddleware(
        max_retries=3,
        initial_delay=0.01,
        jitter=False,
    )
    request = _make_request()

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = temp_fail()
        return ModelResponse(result=[result])

    response = retry.wrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert "Success after 3 attempts" in response.result[0].content


def test_model_retry_specific_exceptions() -> None:
    """Test ModelRetryMiddleware only retries specific exception types."""

    class ValueErrorModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("ValueError occurred")

    class RuntimeErrorModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise RuntimeError("RuntimeError occurred")

    # Only retry ValueError
    retry = ModelRetryMiddleware(
        max_retries=2,
        retry_on=(ValueError,),
        initial_delay=0.01,
        jitter=False,
    )

    # Test that ValueError is retried (should see delay)
    request = _make_request()
    request.model = ValueErrorModel(messages=iter([]))

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    start_time = time.time()
    with pytest.raises(ValueError, match="ValueError occurred"):
        retry.wrap_model_call(request, mock_handler)
    elapsed = time.time() - start_time

    # Should have retried with delays (at least 2 retries with 0.01s delay each)
    assert elapsed >= 0.02

    # Test that RuntimeError is not retried (should fail immediately)
    request.model = RuntimeErrorModel(messages=iter([]))
    start_time = time.time()
    with pytest.raises(RuntimeError, match="RuntimeError occurred"):
        retry.wrap_model_call(request, mock_handler)
    elapsed = time.time() - start_time

    # Should fail immediately without retries
    assert elapsed < 0.01

def test_model_retry_custom_exception_filter() -> None:
    """Test ModelRetryMiddleware with custom exception filter function."""

    class CustomError(Exception):
        """Custom exception with retry_me attribute."""

        def __init__(self, message: str, retry_me: bool):
            """Initialize custom error.

            Args:
                message: Error message.
                retry_me: Whether this error should be retried.
            """
            super().__init__(message)
            self.retry_me = retry_me

    attempt_count = {"value": 0}

    def should_retry(exc: Exception) -> bool:
        return isinstance(exc, CustomError) and exc.retry_me

    retry = ModelRetryMiddleware(
        max_retries=3,
        retry_on=should_retry,
        initial_delay=0.01,
        jitter=False,
    )
    request = _make_request()

    def mock_handler(req: ModelRequest) -> ModelResponse:
        attempt_count["value"] += 1
        if attempt_count["value"] == 1:
            raise CustomError("Retryable error", retry_me=True)
        raise CustomError("Non-retryable error", retry_me=False)

    with pytest.raises(CustomError, match="Non-retryable error"):
        retry.wrap_model_call(request, mock_handler)

    # Should retry once (attempt 1 with retry_me=True), then fail on attempt 2 (retry_me=False)
    assert attempt_count["value"] == 2


def test_model_retry_backoff_timing() -> None:
    """Test ModelRetryMiddleware applies correct backoff delays."""
    temp_fail = TemporaryFailureModel(fail_count=3)

    retry = ModelRetryMiddleware(
        max_retries=3,
        initial_delay=0.1,
        backoff_factor=2.0,
        jitter=False,
    )
    request = _make_request()

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = temp_fail()
        return ModelResponse(result=[result])

    start_time = time.time()
    response = retry.wrap_model_call(request, mock_handler)
    elapsed = time.time() - start_time

    assert isinstance(response, ModelResponse)
    # Expected delays: 0.1 + 0.2 + 0.4 = 0.7 seconds
    # Allow some margin for execution time
    assert elapsed >= 0.6, f"Expected at least 0.6s, got {elapsed}s"


def test_model_retry_constant_backoff() -> None:
    """Test ModelRetryMiddleware with constant backoff (backoff_factor=0)."""
    temp_fail = TemporaryFailureModel(fail_count=2)

    retry = ModelRetryMiddleware(
        max_retries=2,
        initial_delay=0.1,
        backoff_factor=0.0,  # Constant backoff
        jitter=False,
    )
    request = _make_request()

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = temp_fail()
        return ModelResponse(result=[result])

    start_time = time.time()
    response = retry.wrap_model_call(request, mock_handler)
    elapsed = time.time() - start_time

    assert isinstance(response, ModelResponse)
    # Expected delays: 0.1 + 0.1 = 0.2 seconds (constant)
    assert elapsed >= 0.15, f"Expected at least 0.15s, got {elapsed}s"
    assert elapsed < 0.5, f"Expected less than 0.5s (exponential would be longer), got {elapsed}s"

def test_model_retry_max_delay_cap() -> None:
    """Test ModelRetryMiddleware caps delay at max_delay."""
    retry = ModelRetryMiddleware(
        max_retries=5,
        initial_delay=1.0,
        backoff_factor=10.0,  # Very aggressive backoff
        max_delay=2.0,  # Cap at 2 seconds
        jitter=False,
    )

    # Test delay calculation
    delay_0 = retry._calculate_delay(0)  # 1.0
    delay_1 = retry._calculate_delay(1)  # 10.0 -> capped to 2.0
    delay_2 = retry._calculate_delay(2)  # 100.0 -> capped to 2.0

    assert delay_0 == 1.0
    assert delay_1 == 2.0
    assert delay_2 == 2.0


def test_model_retry_jitter_variation() -> None:
    """Test ModelRetryMiddleware adds jitter to delays."""
    retry = ModelRetryMiddleware(
        max_retries=1,
        initial_delay=1.0,
        backoff_factor=1.0,
        jitter=True,
    )

    # Generate multiple delays and ensure they vary
    delays = [retry._calculate_delay(0) for _ in range(10)]

    # All delays should be within ±25% of 1.0 (i.e., between 0.75 and 1.25)
    for delay in delays:
        assert 0.75 <= delay <= 1.25

    # Delays should vary (not all the same)
    assert len(set(delays)) > 1


@pytest.mark.asyncio
async def test_model_retry_async_working_model() -> None:
    """Test ModelRetryMiddleware with async execution and working model."""
    retry = ModelRetryMiddleware(max_retries=2, initial_delay=0.01, jitter=False)
    request = _make_request()

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = AIMessage(content="success")
        return ModelResponse(result=[result])

    response = await retry.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert response.result[0].content == "success"


@pytest.mark.asyncio
async def test_model_retry_async_failing_model() -> None:
    """Test ModelRetryMiddleware with async execution and failing model."""

    class AsyncFailingModel(GenericFakeChatModel):
        async def _agenerate(self, messages, **kwargs):
            raise ValueError("Model failed")

    retry = ModelRetryMiddleware(
        max_retries=2,
        initial_delay=0.01,
        jitter=False,
        on_failure="raise",
    )
    request = _make_request()
    request.model = AsyncFailingModel(messages=iter([]))

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = await req.model.ainvoke([])
        return ModelResponse(result=[result])

    with pytest.raises(ValueError, match="Model failed"):
        await retry.awrap_model_call(request, mock_handler)

@pytest.mark.asyncio
async def test_model_retry_async_succeeds_after_retries() -> None:
    """Test ModelRetryMiddleware async execution succeeds after temporary failures."""
    temp_fail = TemporaryFailureModel(fail_count=2)

    retry = ModelRetryMiddleware(
        max_retries=3,
        initial_delay=0.01,
        jitter=False,
    )
    request = _make_request()

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = temp_fail()
        return ModelResponse(result=[result])

    response = await retry.awrap_model_call(request, mock_handler)

    assert isinstance(response, ModelResponse)
    assert "Success after 3 attempts" in response.result[0].content


@pytest.mark.asyncio
async def test_model_retry_async_backoff_timing() -> None:
    """Test ModelRetryMiddleware async applies correct backoff delays."""
    temp_fail = TemporaryFailureModel(fail_count=3)

    retry = ModelRetryMiddleware(
        max_retries=3,
        initial_delay=0.1,
        backoff_factor=2.0,
        jitter=False,
    )
    request = _make_request()

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        result = temp_fail()
        return ModelResponse(result=[result])

    start_time = time.time()
    response = await retry.awrap_model_call(request, mock_handler)
    elapsed = time.time() - start_time

    assert isinstance(response, ModelResponse)
    # Expected delays: 0.1 + 0.2 + 0.4 = 0.7 seconds
    assert elapsed >= 0.6, f"Expected at least 0.6s, got {elapsed}s"


def test_model_retry_zero_retries() -> None:
    """Test ModelRetryMiddleware with max_retries=0 (no retries)."""

    class FailingModel(GenericFakeChatModel):
        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

    retry = ModelRetryMiddleware(
        max_retries=0,  # No retries
        on_failure="raise",
    )
    request = _make_request()
    request.model = FailingModel(messages=iter([]))

    def mock_handler(req: ModelRequest) -> ModelResponse:
        result = req.model.invoke([])
        return ModelResponse(result=[result])

    # Should fail immediately without retries
    start_time = time.time()
    with pytest.raises(ValueError, match="Model failed"):
        retry.wrap_model_call(request, mock_handler)
    elapsed = time.time() - start_time

    # Should fail immediately (no delay)
    assert elapsed < 0.1


def test_model_retry_with_agent() -> None:
    """Test ModelRetryMiddleware with agent.invoke."""
    attempt_counter = {"value": 0}

    class FailingModel(BaseChatModel):
        """Model that fails twice then succeeds."""

        def _generate(self, messages, **kwargs):
            attempt_counter["value"] += 1
            if attempt_counter["value"] <= 2:
                raise ValueError(f"Failure {attempt_counter['value']}")
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content="Success after retries"))]
            )

        @property
        def _llm_type(self):
            return "failing"

    model = FailingModel()
    retry = ModelRetryMiddleware(max_retries=3, initial_delay=0.01, jitter=False)

    agent = create_agent(model=model, middleware=[retry])

    result = agent.invoke({"messages": [HumanMessage("Test")]})

    # Should have succeeded after retries
    assert len(result["messages"]) == 2
    assert result["messages"][1].content == "Success after retries"


def test_model_retry_exhausted_with_agent() -> None:
    """Test ModelRetryMiddleware with agent.invoke when all retries exhausted."""

    class AlwaysFailingModel(BaseChatModel):
        """Model that always fails."""

        def _generate(self, messages, **kwargs):
            raise ValueError("Model failed")

        @property
        def _llm_type(self):
            return "failing"

    model = AlwaysFailingModel()
    retry = ModelRetryMiddleware(max_retries=2, initial_delay=0.01, jitter=False)

    agent = create_agent(model=model, middleware=[retry])

    # Should fail after exhausting retries
    with pytest.raises(ValueError, match="Model failed"):
        agent.invoke({"messages": [HumanMessage("Test")]})
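The tests above call `awrap_model_call` with a handler directly, and exercise the sync path through `agent.invoke`. End to end, the same async retry loop runs when the agent is invoked asynchronously. A hedged sketch using the fake chat model from these tests; it assumes the compiled agent exposes the standard async `ainvoke` entry point, which this diff does not itself exercise:

```python
import asyncio

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage

from langchain.agents import create_agent
from langchain.agents.middleware import ModelRetryMiddleware


async def main() -> None:
    model = GenericFakeChatModel(messages=iter([AIMessage(content="ok")]))
    agent = create_agent(
        model=model,
        middleware=[ModelRetryMiddleware(max_retries=2, initial_delay=0.5)],
    )
    # Assumed: the compiled agent exposes `ainvoke`, which routes the model call
    # through `awrap_model_call` and therefore through the async retry loop.
    result = await agent.ainvoke({"messages": [HumanMessage("Test")]})
    print(result["messages"][-1].content)


asyncio.run(main())
```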