feat: support structured output retry middleware (#33663)

* attach the latest `AIMessage` to all `StructuredOutputError`s so that
relevant middleware can use as desired
* raise `StructuredOutputError` from `ProviderStrategy` logic in case of
failed parsing (so that we can retry from middleware)
* added a test suite w/ example custom middleware that retries for tool
+ provider strategy

Long term, we could add our own opinionated structured output retry
middleware, but this at least unblocks folks who want to use custom
retry logic in the short term :)

```py
class StructuredOutputRetryMiddleware(AgentMiddleware):
    """Retries model calls when structured output parsing fails."""

    def __init__(self, max_retries: int) -> None:
        self.max_retries = max_retries

    def wrap_model_call(
        self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
    ) -> ModelResponse:
        for attempt in range(self.max_retries + 1):
            try:
                return handler(request)
            except StructuredOutputError as exc:
                if attempt == self.max_retries:
                    raise

                ai_content = exc.ai_message.content
                error_message = (
                    f"Your previous response was:\n{ai_content}\n\n"
                    f"Error: {exc}. Please try again with a valid response."
                )
                request.messages.append(HumanMessage(content=error_message))
```
This commit is contained in:
Sydney Runkle
2025-10-29 08:41:44 -07:00
committed by GitHub
parent 78a2f86f70
commit 8aea6dd23a
4 changed files with 442 additions and 7 deletions

View File

@@ -33,6 +33,7 @@ from langchain.agents.structured_output import (
ProviderStrategy,
ProviderStrategyBinding,
ResponseFormat,
StructuredOutputError,
StructuredOutputValidationError,
ToolStrategy,
)
@@ -797,8 +798,16 @@ def create_agent( # noqa: PLR0915
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
effective_response_format.schema_spec
)
structured_response = provider_strategy_binding.parse(output)
return {"messages": [output], "structured_response": structured_response}
try:
structured_response = provider_strategy_binding.parse(output)
except Exception as exc: # noqa: BLE001
schema_name = getattr(
effective_response_format.schema_spec.schema, "__name__", "response_format"
)
validation_error = StructuredOutputValidationError(schema_name, exc, output)
raise validation_error
else:
return {"messages": [output], "structured_response": structured_response}
return {"messages": [output]}
# Handle structured output with tool strategy
@@ -812,11 +821,11 @@ def create_agent( # noqa: PLR0915
]
if structured_tool_calls:
exception: Exception | None = None
exception: StructuredOutputError | None = None
if len(structured_tool_calls) > 1:
# Handle multiple structured outputs error
tool_names = [tc["name"] for tc in structured_tool_calls]
exception = MultipleStructuredOutputsError(tool_names)
exception = MultipleStructuredOutputsError(tool_names, output)
should_retry, error_message = _handle_structured_output_error(
exception, effective_response_format
)
@@ -858,7 +867,7 @@ def create_agent( # noqa: PLR0915
"structured_response": structured_response,
}
except Exception as exc: # noqa: BLE001
exception = StructuredOutputValidationError(tool_call["name"], exc)
exception = StructuredOutputValidationError(tool_call["name"], exc, output)
should_retry, error_message = _handle_structured_output_error(
exception, effective_response_format
)

View File

@@ -34,17 +34,21 @@ SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
class StructuredOutputError(Exception):
"""Base class for structured output errors."""
ai_message: AIMessage
class MultipleStructuredOutputsError(StructuredOutputError):
"""Raised when model returns multiple structured output tool calls when only one is expected."""
def __init__(self, tool_names: list[str]) -> None:
def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
"""Initialize `MultipleStructuredOutputsError`.
Args:
tool_names: The names of the tools called for structured output.
ai_message: The AI message that contained the invalid multiple tool calls.
"""
self.tool_names = tool_names
self.ai_message = ai_message
super().__init__(
"Model incorrectly returned multiple structured responses "
@@ -55,15 +59,17 @@ class MultipleStructuredOutputsError(StructuredOutputError):
class StructuredOutputValidationError(StructuredOutputError):
"""Raised when structured output tool call arguments fail to parse according to the schema."""
def __init__(self, tool_name: str, source: Exception) -> None:
def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
"""Initialize `StructuredOutputValidationError`.
Args:
tool_name: The name of the tool that failed.
source: The exception that occurred.
ai_message: The AI message that contained the invalid structured output.
"""
self.tool_name = tool_name
self.source = source
self.ai_message = ai_message
super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")

View File

@@ -0,0 +1,369 @@
"""Tests for StructuredOutputRetryMiddleware functionality."""
from collections.abc import Callable
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from pydantic import BaseModel
from langchain.agents import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelRequest,
ModelResponse,
)
from langchain.agents.structured_output import StructuredOutputError, ToolStrategy
from tests.unit_tests.agents.model import FakeToolCallingModel
class StructuredOutputRetryMiddleware(AgentMiddleware):
"""Retries model calls when structured output parsing fails."""
def __init__(self, max_retries: int) -> None:
"""Initialize the structured output retry middleware.
Args:
max_retries: Maximum number of retry attempts.
"""
self.max_retries = max_retries
def wrap_model_call(
self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
"""Intercept and control model execution via handler callback.
Args:
request: The model request containing messages and configuration.
handler: The function to call the model.
Returns:
The model response.
Raises:
StructuredOutputError: If max retries exceeded without success.
"""
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except StructuredOutputError as exc:
if attempt == self.max_retries:
raise
# Include both the AI message and error in a single human message
# to maintain valid chat history alternation
ai_content = exc.ai_message.content
error_message = (
f"Your previous response was:\n{ai_content}\n\n"
f"Error: {exc}. Please try again with a valid response."
)
request.messages.append(HumanMessage(content=error_message))
# This should never be reached, but satisfies type checker
return handler(request)
class WeatherReport(BaseModel):
"""Weather report schema for testing."""
temperature: float
conditions: str
@tool
def get_weather(city: str) -> str:
"""Get the weather for a given city.
Args:
city: The city to get weather for.
Returns:
Weather information for the city.
"""
return f"The weather in {city} is sunny and 72 degrees."
def test_structured_output_retry_first_attempt_invalid() -> None:
"""Test structured output retry when first two attempts have invalid output."""
# First two attempts have invalid tool arguments, third attempt succeeds
# The model will call the WeatherReport structured output tool
tool_calls = [
# First attempt - invalid: wrong type for temperature
[
{
"name": "WeatherReport",
"id": "1",
"args": {"temperature": "not-a-float", "conditions": "sunny"},
}
],
# Second attempt - invalid: missing required field
[{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}],
# Third attempt - valid
[
{
"name": "WeatherReport",
"id": "3",
"args": {"temperature": 72.5, "conditions": "sunny"},
}
],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[retry_middleware],
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
checkpointer=InMemorySaver(),
)
result = agent.invoke(
{"messages": [HumanMessage("What's the weather in Tokyo?")]},
{"configurable": {"thread_id": "test"}},
)
# Verify we got a structured response
assert "structured_response" in result
structured = result["structured_response"]
assert isinstance(structured, WeatherReport)
assert structured.temperature == 72.5
assert structured.conditions == "sunny"
# Verify the model was called 3 times (initial + 2 retries)
assert model.index == 3
def test_structured_output_retry_exceeds_max_retries() -> None:
"""Test structured output retry raises error when max retries exceeded."""
# All three attempts return invalid arguments
tool_calls = [
[
{
"name": "WeatherReport",
"id": "1",
"args": {"temperature": "invalid", "conditions": "sunny"},
}
],
[
{
"name": "WeatherReport",
"id": "2",
"args": {"temperature": "also-invalid", "conditions": "cloudy"},
}
],
[
{
"name": "WeatherReport",
"id": "3",
"args": {"temperature": "still-invalid", "conditions": "rainy"},
}
],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[retry_middleware],
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
# No checkpointer - we expect this to fail
)
# Should raise StructuredOutputError after exhausting retries
with pytest.raises(StructuredOutputError):
agent.invoke(
{"messages": [HumanMessage("What's the weather in Tokyo?")]},
)
# Verify the model was called 3 times (initial + 2 retries)
assert model.index == 3
def test_structured_output_retry_succeeds_first_attempt() -> None:
"""Test structured output retry when first attempt succeeds (no retry needed)."""
# First attempt returns valid structured output
tool_calls = [
[
{
"name": "WeatherReport",
"id": "1",
"args": {"temperature": 68.0, "conditions": "cloudy"},
}
],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[retry_middleware],
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
checkpointer=InMemorySaver(),
)
result = agent.invoke(
{"messages": [HumanMessage("What's the weather in Paris?")]},
{"configurable": {"thread_id": "test"}},
)
# Verify we got a structured response
assert "structured_response" in result
structured = result["structured_response"]
assert isinstance(structured, WeatherReport)
assert structured.temperature == 68.0
assert structured.conditions == "cloudy"
# Verify the model was called only once
assert model.index == 1
def test_structured_output_retry_validation_error() -> None:
"""Test structured output retry with schema validation errors."""
# First attempt has wrong type, second has missing field, third succeeds
tool_calls = [
[
{
"name": "WeatherReport",
"id": "1",
"args": {"temperature": "seventy-two", "conditions": "sunny"},
}
],
[{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}],
[
{
"name": "WeatherReport",
"id": "3",
"args": {"temperature": 72.5, "conditions": "partly cloudy"},
}
],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[retry_middleware],
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
checkpointer=InMemorySaver(),
)
result = agent.invoke(
{"messages": [HumanMessage("What's the weather in London?")]},
{"configurable": {"thread_id": "test"}},
)
# Verify we got a structured response
assert "structured_response" in result
structured = result["structured_response"]
assert isinstance(structured, WeatherReport)
assert structured.temperature == 72.5
assert structured.conditions == "partly cloudy"
# Verify the model was called 3 times
assert model.index == 3
def test_structured_output_retry_zero_retries() -> None:
"""Test structured output retry with max_retries=0 (no retries allowed)."""
# First attempt returns invalid arguments
tool_calls = [
[
{
"name": "WeatherReport",
"id": "1",
"args": {"temperature": "invalid", "conditions": "sunny"},
}
],
[
{
"name": "WeatherReport",
"id": "2",
"args": {"temperature": 72.5, "conditions": "sunny"},
}
], # Would succeed if retried
]
model = FakeToolCallingModel(tool_calls=tool_calls)
retry_middleware = StructuredOutputRetryMiddleware(max_retries=0)
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[retry_middleware],
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
checkpointer=InMemorySaver(),
)
# Should fail immediately without retrying
with pytest.raises(StructuredOutputError):
agent.invoke(
{"messages": [HumanMessage("What's the weather in Berlin?")]},
{"configurable": {"thread_id": "test"}},
)
# Verify the model was called only once (no retries)
assert model.index == 1
def test_structured_output_retry_preserves_messages() -> None:
"""Test structured output retry preserves error feedback in messages."""
# First attempt invalid, second succeeds
tool_calls = [
[
{
"name": "WeatherReport",
"id": "1",
"args": {"temperature": "invalid", "conditions": "rainy"},
}
],
[
{
"name": "WeatherReport",
"id": "2",
"args": {"temperature": 75.0, "conditions": "rainy"},
}
],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
retry_middleware = StructuredOutputRetryMiddleware(max_retries=1)
agent = create_agent(
model=model,
tools=[get_weather],
middleware=[retry_middleware],
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
checkpointer=InMemorySaver(),
)
result = agent.invoke(
{"messages": [HumanMessage("What's the weather in Seattle?")]},
{"configurable": {"thread_id": "test"}},
)
# Verify structured response is correct
assert "structured_response" in result
structured = result["structured_response"]
assert structured.temperature == 75.0
assert structured.conditions == "rainy"
# Verify messages include the retry feedback
messages = result["messages"]
human_messages = [m for m in messages if isinstance(m, HumanMessage)]
# Should have at least 2 human messages: initial + retry feedback
assert len(human_messages) >= 2
# The retry feedback message should contain error information
retry_message = human_messages[-1]
assert "Error:" in retry_message.content
assert "Please try again" in retry_message.content

View File

@@ -610,6 +610,35 @@ class TestResponseFormatAsToolStrategy:
)
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
def test_validation_error_with_invalid_response(self) -> None:
"""Test that StructuredOutputValidationError is raised when tool strategy receives invalid response."""
tool_calls = [
[
{
"name": "WeatherBaseModel",
"id": "1",
"args": {"invalid_field": "wrong_data", "another_bad_field": 123},
},
],
]
model = FakeToolCallingModel(tool_calls=tool_calls)
agent = create_agent(
model,
[],
response_format=ToolStrategy(
WeatherBaseModel,
handle_errors=False, # Disable retry to ensure error is raised
),
)
with pytest.raises(
StructuredOutputValidationError,
match=".*WeatherBaseModel.*",
):
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
class TestResponseFormatAsProviderStrategy:
def test_pydantic_model(self) -> None:
@@ -630,6 +659,28 @@ class TestResponseFormatAsProviderStrategy:
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
assert len(response["messages"]) == 4
def test_validation_error_with_invalid_response(self) -> None:
"""Test that StructuredOutputValidationError is raised when provider strategy receives invalid response."""
tool_calls = [
[{"args": {}, "id": "1", "name": "get_weather"}],
]
# But we're using WeatherBaseModel which has different field requirements
model = FakeToolCallingModel[dict](
tool_calls=tool_calls,
structured_response={"invalid": "data"}, # Wrong structure
)
agent = create_agent(
model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
)
with pytest.raises(
StructuredOutputValidationError,
match=".*WeatherBaseModel.*",
):
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
def test_dataclass(self) -> None:
"""Test response_format as ProviderStrategy with dataclass."""
tool_calls = [