mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat: support structured output retry middleware (#33663)
* attach the latest `AIMessage` to all `StructuredOutputError`s so that
relevant middleware can use as desired
* raise `StructuredOutputError` from `ProviderStrategy` logic in case of
failed parsing (so that we can retry from middleware)
* added a test suite w/ example custom middleware that retries for tool
+ provider strategy
Long term, we could add our own opinionated structured output retry
middleware, but this at least unblocks folks who want to use custom
retry logic in the short term :)
```py
class StructuredOutputRetryMiddleware(AgentMiddleware):
"""Retries model calls when structured output parsing fails."""
def __init__(self, max_retries: int) -> None:
self.max_retries = max_retries
def wrap_model_call(
self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except StructuredOutputError as exc:
if attempt == self.max_retries:
raise
ai_content = exc.ai_message.content
error_message = (
f"Your previous response was:\n{ai_content}\n\n"
f"Error: {exc}. Please try again with a valid response."
)
request.messages.append(HumanMessage(content=error_message))
```
This commit is contained in:
@@ -33,6 +33,7 @@ from langchain.agents.structured_output import (
|
||||
ProviderStrategy,
|
||||
ProviderStrategyBinding,
|
||||
ResponseFormat,
|
||||
StructuredOutputError,
|
||||
StructuredOutputValidationError,
|
||||
ToolStrategy,
|
||||
)
|
||||
@@ -797,8 +798,16 @@ def create_agent( # noqa: PLR0915
|
||||
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
|
||||
effective_response_format.schema_spec
|
||||
)
|
||||
structured_response = provider_strategy_binding.parse(output)
|
||||
return {"messages": [output], "structured_response": structured_response}
|
||||
try:
|
||||
structured_response = provider_strategy_binding.parse(output)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
schema_name = getattr(
|
||||
effective_response_format.schema_spec.schema, "__name__", "response_format"
|
||||
)
|
||||
validation_error = StructuredOutputValidationError(schema_name, exc, output)
|
||||
raise validation_error
|
||||
else:
|
||||
return {"messages": [output], "structured_response": structured_response}
|
||||
return {"messages": [output]}
|
||||
|
||||
# Handle structured output with tool strategy
|
||||
@@ -812,11 +821,11 @@ def create_agent( # noqa: PLR0915
|
||||
]
|
||||
|
||||
if structured_tool_calls:
|
||||
exception: Exception | None = None
|
||||
exception: StructuredOutputError | None = None
|
||||
if len(structured_tool_calls) > 1:
|
||||
# Handle multiple structured outputs error
|
||||
tool_names = [tc["name"] for tc in structured_tool_calls]
|
||||
exception = MultipleStructuredOutputsError(tool_names)
|
||||
exception = MultipleStructuredOutputsError(tool_names, output)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, effective_response_format
|
||||
)
|
||||
@@ -858,7 +867,7 @@ def create_agent( # noqa: PLR0915
|
||||
"structured_response": structured_response,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exception = StructuredOutputValidationError(tool_call["name"], exc)
|
||||
exception = StructuredOutputValidationError(tool_call["name"], exc, output)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, effective_response_format
|
||||
)
|
||||
|
||||
@@ -34,17 +34,21 @@ SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
|
||||
class StructuredOutputError(Exception):
|
||||
"""Base class for structured output errors."""
|
||||
|
||||
ai_message: AIMessage
|
||||
|
||||
|
||||
class MultipleStructuredOutputsError(StructuredOutputError):
|
||||
"""Raised when model returns multiple structured output tool calls when only one is expected."""
|
||||
|
||||
def __init__(self, tool_names: list[str]) -> None:
|
||||
def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
|
||||
"""Initialize `MultipleStructuredOutputsError`.
|
||||
|
||||
Args:
|
||||
tool_names: The names of the tools called for structured output.
|
||||
ai_message: The AI message that contained the invalid multiple tool calls.
|
||||
"""
|
||||
self.tool_names = tool_names
|
||||
self.ai_message = ai_message
|
||||
|
||||
super().__init__(
|
||||
"Model incorrectly returned multiple structured responses "
|
||||
@@ -55,15 +59,17 @@ class MultipleStructuredOutputsError(StructuredOutputError):
|
||||
class StructuredOutputValidationError(StructuredOutputError):
|
||||
"""Raised when structured output tool call arguments fail to parse according to the schema."""
|
||||
|
||||
def __init__(self, tool_name: str, source: Exception) -> None:
|
||||
def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
|
||||
"""Initialize `StructuredOutputValidationError`.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool that failed.
|
||||
source: The exception that occurred.
|
||||
ai_message: The AI message that contained the invalid structured output.
|
||||
"""
|
||||
self.tool_name = tool_name
|
||||
self.source = source
|
||||
self.ai_message = ai_message
|
||||
super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,369 @@
|
||||
"""Tests for StructuredOutputRetryMiddleware functionality."""
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain.agents.structured_output import StructuredOutputError, ToolStrategy
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
class StructuredOutputRetryMiddleware(AgentMiddleware):
|
||||
"""Retries model calls when structured output parsing fails."""
|
||||
|
||||
def __init__(self, max_retries: int) -> None:
|
||||
"""Initialize the structured output retry middleware.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts.
|
||||
"""
|
||||
self.max_retries = max_retries
|
||||
|
||||
def wrap_model_call(
|
||||
self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
) -> ModelResponse:
|
||||
"""Intercept and control model execution via handler callback.
|
||||
|
||||
Args:
|
||||
request: The model request containing messages and configuration.
|
||||
handler: The function to call the model.
|
||||
|
||||
Returns:
|
||||
The model response.
|
||||
|
||||
Raises:
|
||||
StructuredOutputError: If max retries exceeded without success.
|
||||
"""
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return handler(request)
|
||||
except StructuredOutputError as exc:
|
||||
if attempt == self.max_retries:
|
||||
raise
|
||||
|
||||
# Include both the AI message and error in a single human message
|
||||
# to maintain valid chat history alternation
|
||||
ai_content = exc.ai_message.content
|
||||
error_message = (
|
||||
f"Your previous response was:\n{ai_content}\n\n"
|
||||
f"Error: {exc}. Please try again with a valid response."
|
||||
)
|
||||
request.messages.append(HumanMessage(content=error_message))
|
||||
|
||||
# This should never be reached, but satisfies type checker
|
||||
return handler(request)
|
||||
|
||||
|
||||
class WeatherReport(BaseModel):
|
||||
"""Weather report schema for testing."""
|
||||
|
||||
temperature: float
|
||||
conditions: str
|
||||
|
||||
|
||||
@tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get the weather for a given city.
|
||||
|
||||
Args:
|
||||
city: The city to get weather for.
|
||||
|
||||
Returns:
|
||||
Weather information for the city.
|
||||
"""
|
||||
return f"The weather in {city} is sunny and 72 degrees."
|
||||
|
||||
|
||||
def test_structured_output_retry_first_attempt_invalid() -> None:
|
||||
"""Test structured output retry when first two attempts have invalid output."""
|
||||
# First two attempts have invalid tool arguments, third attempt succeeds
|
||||
# The model will call the WeatherReport structured output tool
|
||||
tool_calls = [
|
||||
# First attempt - invalid: wrong type for temperature
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "1",
|
||||
"args": {"temperature": "not-a-float", "conditions": "sunny"},
|
||||
}
|
||||
],
|
||||
# Second attempt - invalid: missing required field
|
||||
[{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}],
|
||||
# Third attempt - valid
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "3",
|
||||
"args": {"temperature": 72.5, "conditions": "sunny"},
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[get_weather],
|
||||
middleware=[retry_middleware],
|
||||
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather in Tokyo?")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Verify we got a structured response
|
||||
assert "structured_response" in result
|
||||
structured = result["structured_response"]
|
||||
assert isinstance(structured, WeatherReport)
|
||||
assert structured.temperature == 72.5
|
||||
assert structured.conditions == "sunny"
|
||||
|
||||
# Verify the model was called 3 times (initial + 2 retries)
|
||||
assert model.index == 3
|
||||
|
||||
|
||||
def test_structured_output_retry_exceeds_max_retries() -> None:
|
||||
"""Test structured output retry raises error when max retries exceeded."""
|
||||
# All three attempts return invalid arguments
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "1",
|
||||
"args": {"temperature": "invalid", "conditions": "sunny"},
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "2",
|
||||
"args": {"temperature": "also-invalid", "conditions": "cloudy"},
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "3",
|
||||
"args": {"temperature": "still-invalid", "conditions": "rainy"},
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[get_weather],
|
||||
middleware=[retry_middleware],
|
||||
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
|
||||
# No checkpointer - we expect this to fail
|
||||
)
|
||||
|
||||
# Should raise StructuredOutputError after exhausting retries
|
||||
with pytest.raises(StructuredOutputError):
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather in Tokyo?")]},
|
||||
)
|
||||
|
||||
# Verify the model was called 3 times (initial + 2 retries)
|
||||
assert model.index == 3
|
||||
|
||||
|
||||
def test_structured_output_retry_succeeds_first_attempt() -> None:
|
||||
"""Test structured output retry when first attempt succeeds (no retry needed)."""
|
||||
# First attempt returns valid structured output
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "1",
|
||||
"args": {"temperature": 68.0, "conditions": "cloudy"},
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[get_weather],
|
||||
middleware=[retry_middleware],
|
||||
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather in Paris?")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Verify we got a structured response
|
||||
assert "structured_response" in result
|
||||
structured = result["structured_response"]
|
||||
assert isinstance(structured, WeatherReport)
|
||||
assert structured.temperature == 68.0
|
||||
assert structured.conditions == "cloudy"
|
||||
|
||||
# Verify the model was called only once
|
||||
assert model.index == 1
|
||||
|
||||
|
||||
def test_structured_output_retry_validation_error() -> None:
|
||||
"""Test structured output retry with schema validation errors."""
|
||||
# First attempt has wrong type, second has missing field, third succeeds
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "1",
|
||||
"args": {"temperature": "seventy-two", "conditions": "sunny"},
|
||||
}
|
||||
],
|
||||
[{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}],
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "3",
|
||||
"args": {"temperature": 72.5, "conditions": "partly cloudy"},
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
retry_middleware = StructuredOutputRetryMiddleware(max_retries=2)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[get_weather],
|
||||
middleware=[retry_middleware],
|
||||
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather in London?")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Verify we got a structured response
|
||||
assert "structured_response" in result
|
||||
structured = result["structured_response"]
|
||||
assert isinstance(structured, WeatherReport)
|
||||
assert structured.temperature == 72.5
|
||||
assert structured.conditions == "partly cloudy"
|
||||
|
||||
# Verify the model was called 3 times
|
||||
assert model.index == 3
|
||||
|
||||
|
||||
def test_structured_output_retry_zero_retries() -> None:
|
||||
"""Test structured output retry with max_retries=0 (no retries allowed)."""
|
||||
# First attempt returns invalid arguments
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "1",
|
||||
"args": {"temperature": "invalid", "conditions": "sunny"},
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "2",
|
||||
"args": {"temperature": 72.5, "conditions": "sunny"},
|
||||
}
|
||||
], # Would succeed if retried
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
retry_middleware = StructuredOutputRetryMiddleware(max_retries=0)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[get_weather],
|
||||
middleware=[retry_middleware],
|
||||
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
# Should fail immediately without retrying
|
||||
with pytest.raises(StructuredOutputError):
|
||||
agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather in Berlin?")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Verify the model was called only once (no retries)
|
||||
assert model.index == 1
|
||||
|
||||
|
||||
def test_structured_output_retry_preserves_messages() -> None:
|
||||
"""Test structured output retry preserves error feedback in messages."""
|
||||
# First attempt invalid, second succeeds
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "1",
|
||||
"args": {"temperature": "invalid", "conditions": "rainy"},
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"name": "WeatherReport",
|
||||
"id": "2",
|
||||
"args": {"temperature": 75.0, "conditions": "rainy"},
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
retry_middleware = StructuredOutputRetryMiddleware(max_retries=1)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[get_weather],
|
||||
middleware=[retry_middleware],
|
||||
response_format=ToolStrategy(schema=WeatherReport, handle_errors=False),
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{"messages": [HumanMessage("What's the weather in Seattle?")]},
|
||||
{"configurable": {"thread_id": "test"}},
|
||||
)
|
||||
|
||||
# Verify structured response is correct
|
||||
assert "structured_response" in result
|
||||
structured = result["structured_response"]
|
||||
assert structured.temperature == 75.0
|
||||
assert structured.conditions == "rainy"
|
||||
|
||||
# Verify messages include the retry feedback
|
||||
messages = result["messages"]
|
||||
human_messages = [m for m in messages if isinstance(m, HumanMessage)]
|
||||
|
||||
# Should have at least 2 human messages: initial + retry feedback
|
||||
assert len(human_messages) >= 2
|
||||
|
||||
# The retry feedback message should contain error information
|
||||
retry_message = human_messages[-1]
|
||||
assert "Error:" in retry_message.content
|
||||
assert "Please try again" in retry_message.content
|
||||
@@ -610,6 +610,35 @@ class TestResponseFormatAsToolStrategy:
|
||||
)
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
|
||||
def test_validation_error_with_invalid_response(self) -> None:
|
||||
"""Test that StructuredOutputValidationError is raised when tool strategy receives invalid response."""
|
||||
tool_calls = [
|
||||
[
|
||||
{
|
||||
"name": "WeatherBaseModel",
|
||||
"id": "1",
|
||||
"args": {"invalid_field": "wrong_data", "another_bad_field": 123},
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel(tool_calls=tool_calls)
|
||||
|
||||
agent = create_agent(
|
||||
model,
|
||||
[],
|
||||
response_format=ToolStrategy(
|
||||
WeatherBaseModel,
|
||||
handle_errors=False, # Disable retry to ensure error is raised
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
StructuredOutputValidationError,
|
||||
match=".*WeatherBaseModel.*",
|
||||
):
|
||||
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
|
||||
class TestResponseFormatAsProviderStrategy:
|
||||
def test_pydantic_model(self) -> None:
|
||||
@@ -630,6 +659,28 @@ class TestResponseFormatAsProviderStrategy:
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
def test_validation_error_with_invalid_response(self) -> None:
|
||||
"""Test that StructuredOutputValidationError is raised when provider strategy receives invalid response."""
|
||||
tool_calls = [
|
||||
[{"args": {}, "id": "1", "name": "get_weather"}],
|
||||
]
|
||||
|
||||
# But we're using WeatherBaseModel which has different field requirements
|
||||
model = FakeToolCallingModel[dict](
|
||||
tool_calls=tool_calls,
|
||||
structured_response={"invalid": "data"}, # Wrong structure
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
StructuredOutputValidationError,
|
||||
match=".*WeatherBaseModel.*",
|
||||
):
|
||||
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
def test_dataclass(self) -> None:
|
||||
"""Test response_format as ProviderStrategy with dataclass."""
|
||||
tool_calls = [
|
||||
|
||||
Reference in New Issue
Block a user