Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
b3d26b0990 retry example 2025-10-24 09:56:31 -04:00
4 changed files with 342 additions and 33 deletions

View File

@@ -797,8 +797,27 @@ def create_agent( # noqa: PLR0915
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
effective_response_format.schema_spec
)
structured_response = provider_strategy_binding.parse(output)
return {"messages": [output], "structured_response": structured_response}
# Retry logic for provider strategy
last_exception: Exception | None = None
for attempt in range(effective_response_format.max_retries + 1):
try:
structured_response = provider_strategy_binding.parse(output)
return {"messages": [output], "structured_response": structured_response}
except Exception as exc:
last_exception = exc
if attempt >= effective_response_format.max_retries:
break
# Handle failure after retries exhausted
if effective_response_format.on_failure == "raise":
raise last_exception # type: ignore[misc]
elif callable(effective_response_format.on_failure):
msg = effective_response_format.on_failure(last_exception) # type: ignore[arg-type]
raise StructuredOutputValidationError(
effective_response_format.schema_spec.name, Exception(msg)
) from last_exception
return {"messages": [output]}
# Handle structured output with tool strategy
@@ -834,37 +853,47 @@ def create_agent( # noqa: PLR0915
]
return {"messages": [output, *tool_messages]}
# Handle single structured output
# Handle single structured output with retry
tool_call = structured_tool_calls[0]
try:
structured_tool_binding = structured_output_tools[tool_call["name"]]
structured_response = structured_tool_binding.parse(tool_call["args"])
structured_tool_binding = structured_output_tools[tool_call["name"]]
tool_message_content = (
effective_response_format.tool_message_content
if effective_response_format.tool_message_content
else f"Returning structured response: {structured_response}"
)
# Retry logic for tool strategy
last_exception: Exception | None = None
for attempt in range(effective_response_format.max_retries + 1):
try:
structured_response = structured_tool_binding.parse(tool_call["args"])
return {
"messages": [
output,
ToolMessage(
content=tool_message_content,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
],
"structured_response": structured_response,
}
except Exception as exc: # noqa: BLE001
exception = StructuredOutputValidationError(tool_call["name"], exc)
should_retry, error_message = _handle_structured_output_error(
exception, effective_response_format
)
if not should_retry:
raise exception
tool_message_content = (
effective_response_format.tool_message_content
if effective_response_format.tool_message_content
else f"Returning structured response: {structured_response}"
)
return {
"messages": [
output,
ToolMessage(
content=tool_message_content,
tool_call_id=tool_call["id"],
name=tool_call["name"],
),
],
"structured_response": structured_response,
}
except Exception as exc: # noqa: BLE001
last_exception = exc
if attempt >= effective_response_format.max_retries:
break
# Handle failure after retries exhausted
exception = StructuredOutputValidationError(tool_call["name"], last_exception) # type: ignore[arg-type]
# Check if we should send error to model for self-correction
should_retry, error_message = _handle_structured_output_error(
exception, effective_response_format
)
if should_retry:
# Model self-correction mode
return {
"messages": [
output,
@@ -876,6 +905,17 @@ def create_agent( # noqa: PLR0915
],
}
# Apply on_failure behavior
if effective_response_format.on_failure == "raise":
raise exception
elif callable(effective_response_format.on_failure):
msg = effective_response_format.on_failure(last_exception) # type: ignore[arg-type]
raise StructuredOutputValidationError(
tool_call["name"], Exception(msg)
) from last_exception
else:
raise exception
return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:

View File

@@ -199,6 +199,16 @@ class ToolStrategy(Generic[SchemaT]):
- `False`: No retry, let exceptions propagate
"""
max_retries: int
"""Maximum number of retry attempts for parsing failures (default: 0)."""
on_failure: Literal["raise"] | Callable[[Exception], str]
"""Behavior when retries are exhausted.
- `"raise"`: Raise the exception (default)
- `Callable[[Exception], str]`: Custom error message function
"""
def __init__(
self,
schema: type[SchemaT],
@@ -209,15 +219,23 @@ class ToolStrategy(Generic[SchemaT]):
| type[Exception]
| tuple[type[Exception], ...]
| Callable[[Exception], str] = True,
max_retries: int = 0,
on_failure: Literal["raise"] | Callable[[Exception], str] = "raise",
) -> None:
"""Initialize `ToolStrategy`.
Initialize `ToolStrategy` with schemas, tool message content, and error handling
strategy.
Args:
schema: The schema for structured output.
tool_message_content: Custom content for tool response messages.
handle_errors: Error handling strategy for customizing error messages.
max_retries: Maximum number of retry attempts (0 = no retry).
on_failure: Behavior when retries exhausted (`"raise"` or custom function).
"""
self.schema = schema
self.tool_message_content = tool_message_content
self.handle_errors = handle_errors
self.max_retries = max_retries
self.on_failure = on_failure
def _iter_variants(schema: Any) -> Iterable[Any]:
"""Yield leaf variants from Union and JSON Schema oneOf."""
@@ -246,13 +264,34 @@ class ProviderStrategy(Generic[SchemaT]):
schema_spec: _SchemaSpec[SchemaT]
"""Schema spec for native mode."""
max_retries: int
"""Maximum number of retry attempts for parsing failures (default: 0)."""
on_failure: Literal["raise"] | Callable[[Exception], str]
"""Behavior when retries are exhausted.
- `"raise"`: Raise the exception (default)
- `Callable[[Exception], str]`: Custom error message function
"""
def __init__(
self,
schema: type[SchemaT],
*,
max_retries: int = 0,
on_failure: Literal["raise"] | Callable[[Exception], str] = "raise",
) -> None:
"""Initialize ProviderStrategy with schema."""
"""Initialize ProviderStrategy with schema.
Args:
schema: The schema for structured output.
max_retries: Maximum number of retry attempts (0 = no retry).
on_failure: Behavior when retries exhausted (`"raise"` or custom function).
"""
self.schema = schema
self.schema_spec = _SchemaSpec(schema)
self.max_retries = max_retries
self.on_failure = on_failure
def to_model_kwargs(self) -> dict[str, Any]:
"""Convert to kwargs to bind to a model to force structured output."""

View File

@@ -0,0 +1,230 @@
"""Unit tests for structured output retry functionality."""
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessage
from pydantic import BaseModel, ValidationError
from langchain.agents.structured_output import (
OutputToolBinding,
ProviderStrategy,
ProviderStrategyBinding,
StructuredOutputValidationError,
ToolStrategy,
_SchemaSpec,
)
class ResponseSchema(BaseModel):
"""Test schema for structured output."""
value: int
class TestProviderStrategyRetry:
"""Test retry functionality for ProviderStrategy."""
def test_provider_strategy_default_no_retry(self):
"""Test that ProviderStrategy defaults to no retry."""
strategy = ProviderStrategy(ResponseSchema)
assert strategy.max_retries == 0
assert strategy.on_failure == "raise"
def test_provider_strategy_with_max_retries(self):
"""Test ProviderStrategy with max_retries configured."""
strategy = ProviderStrategy(ResponseSchema, max_retries=3)
assert strategy.max_retries == 3
def test_provider_strategy_with_custom_on_failure(self):
"""Test ProviderStrategy with custom on_failure function."""
def custom_failure(exc: Exception) -> str:
return f"Custom error: {exc}"
strategy = ProviderStrategy(ResponseSchema, on_failure=custom_failure)
assert callable(strategy.on_failure)
def test_provider_binding_retry_succeeds_on_second_attempt(self):
"""Test that retry succeeds when second attempt works."""
binding = ProviderStrategyBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
# Mock parse to fail once then succeed
original_parse = binding.parse
call_count = 0
def mock_parse(msg):
nonlocal call_count
call_count += 1
if call_count == 1:
raise ValueError("First attempt fails")
return ResponseSchema(value=42)
binding.parse = mock_parse # type: ignore[method-assign]
# Simulate retry logic (would be in factory.py)
max_retries = 2
last_exception = None
for attempt in range(max_retries + 1):
try:
result = binding.parse(Mock())
assert result.value == 42
assert call_count == 2 # Failed once, succeeded on retry
break
except Exception as exc:
last_exception = exc
if attempt >= max_retries:
raise
def test_provider_binding_retry_exhausted(self):
"""Test that retries are exhausted and exception is raised."""
binding = ProviderStrategyBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
# Mock parse to always fail
def mock_parse(msg):
raise ValueError("Always fails")
binding.parse = mock_parse # type: ignore[method-assign]
# Simulate retry logic
max_retries = 2
last_exception = None
attempts = 0
for attempt in range(max_retries + 1):
try:
binding.parse(Mock())
except Exception as exc:
attempts += 1
last_exception = exc
if attempt >= max_retries:
break
assert attempts == 3 # Initial + 2 retries
assert last_exception is not None
assert "Always fails" in str(last_exception)
class TestToolStrategyRetry:
"""Test retry functionality for ToolStrategy."""
def test_tool_strategy_default_no_retry(self):
"""Test that ToolStrategy defaults to no retry."""
strategy = ToolStrategy(ResponseSchema)
assert strategy.max_retries == 0
assert strategy.on_failure == "raise"
def test_tool_strategy_with_max_retries(self):
"""Test ToolStrategy with max_retries configured."""
strategy = ToolStrategy(ResponseSchema, max_retries=3)
assert strategy.max_retries == 3
def test_tool_strategy_retains_handle_errors(self):
"""Test that ToolStrategy retains handle_errors for backward compatibility."""
strategy = ToolStrategy(ResponseSchema, handle_errors=False, max_retries=2)
assert strategy.handle_errors is False
assert strategy.max_retries == 2
def test_tool_strategy_custom_error_message(self):
"""Test ToolStrategy with custom error message function."""
def custom_error(exc: Exception) -> str:
return f"Validation failed: {exc}"
strategy = ToolStrategy(ResponseSchema, handle_errors=custom_error)
assert callable(strategy.handle_errors)
def test_tool_binding_retry_succeeds_on_third_attempt(self):
"""Test that retry succeeds when third attempt works."""
binding = OutputToolBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
# Mock parse to fail twice then succeed
call_count = 0
def mock_parse(args):
nonlocal call_count
call_count += 1
if call_count < 3:
raise ValidationError.from_exception_data(
"ResponseSchema",
[{"type": "missing", "loc": ("value",), "msg": "Field required"}],
)
return ResponseSchema(value=100)
binding.parse = mock_parse # type: ignore[method-assign]
# Simulate retry logic
max_retries = 3
for attempt in range(max_retries + 1):
try:
result = binding.parse({"value": 100})
assert result.value == 100
assert call_count == 3 # Failed twice, succeeded on third
break
except Exception:
if attempt >= max_retries:
raise
def test_tool_binding_retry_with_on_failure_callable(self):
"""Test custom on_failure function is applied correctly."""
def custom_failure(exc: Exception) -> str:
return f"Custom message: {type(exc).__name__}"
binding = OutputToolBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
def mock_parse(args):
raise ValidationError.from_exception_data(
"ResponseSchema",
[{"type": "int_type", "loc": ("value",), "msg": "Input should be an integer"}],
)
binding.parse = mock_parse # type: ignore[method-assign]
# Simulate retry with custom failure
max_retries = 1
last_exception = None
for attempt in range(max_retries + 1):
try:
binding.parse({"value": "not_an_int"})
except Exception as exc:
last_exception = exc
if attempt >= max_retries:
break
# Apply custom failure handler
error_msg = custom_failure(last_exception) # type: ignore[arg-type]
assert "Custom message: ValidationError" in error_msg
class TestRetryBehavior:
"""Test overall retry behavior."""
def test_max_retries_zero_means_no_retry(self):
"""Test that max_retries=0 means single attempt only."""
strategy = ToolStrategy(ResponseSchema, max_retries=0)
# With max_retries=0, should only attempt once
attempts = 0
for attempt in range(strategy.max_retries + 1):
attempts += 1
assert attempts == 1
def test_max_retries_three_means_four_attempts(self):
"""Test that max_retries=3 means 4 total attempts."""
strategy = ProviderStrategy(ResponseSchema, max_retries=3)
# With max_retries=3, should attempt 4 times total
attempts = 0
for attempt in range(strategy.max_retries + 1):
attempts += 1
assert attempts == 4
def test_on_failure_raise_is_default(self):
"""Test that on_failure defaults to 'raise'."""
tool_strategy = ToolStrategy(ResponseSchema)
provider_strategy = ProviderStrategy(ResponseSchema)
assert tool_strategy.on_failure == "raise"
assert provider_strategy.on_failure == "raise"

View File

@@ -1743,7 +1743,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.0.0"
version = "1.0.1"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },