From 8aea6dd23a11daff336773dc5c8911f600ef8dcd Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 29 Oct 2025 08:41:44 -0700 Subject: [PATCH] feat: support structured output retry middleware (#33663) * attach the latest `AIMessage` to all `StructuredOutputError`s so that relevant middleware can use as desired * raise `StructuredOutputError` from `ProviderStrategy` logic in case of failed parsing (so that we can retry from middleware) * added a test suite w/ example custom middleware that retries for tool + provider strategy Long term, we could add our own opinionated structured output retry middleware, but this at least unblocks folks who want to use custom retry logic in the short term :) ```py class StructuredOutputRetryMiddleware(AgentMiddleware): """Retries model calls when structured output parsing fails.""" def __init__(self, max_retries: int) -> None: self.max_retries = max_retries def wrap_model_call( self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] ) -> ModelResponse: for attempt in range(self.max_retries + 1): try: return handler(request) except StructuredOutputError as exc: if attempt == self.max_retries: raise ai_content = exc.ai_message.content error_message = ( f"Your previous response was:\n{ai_content}\n\n" f"Error: {exc}. Please try again with a valid response." ) request.messages.append(HumanMessage(content=error_message)) ``` --- libs/langchain_v1/langchain/agents/factory.py | 19 +- .../langchain/agents/structured_output.py | 10 +- .../test_structured_output_retry.py | 369 ++++++++++++++++++ .../unit_tests/agents/test_response_format.py | 51 +++ 4 files changed, 442 insertions(+), 7 deletions(-) create mode 100644 libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index 2f9962759fc..59ac049b13f 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -33,6 +33,7 @@ from langchain.agents.structured_output import ( ProviderStrategy, ProviderStrategyBinding, ResponseFormat, + StructuredOutputError, StructuredOutputValidationError, ToolStrategy, ) @@ -797,8 +798,16 @@ def create_agent( # noqa: PLR0915 provider_strategy_binding = ProviderStrategyBinding.from_schema_spec( effective_response_format.schema_spec ) - structured_response = provider_strategy_binding.parse(output) - return {"messages": [output], "structured_response": structured_response} + try: + structured_response = provider_strategy_binding.parse(output) + except Exception as exc: # noqa: BLE001 + schema_name = getattr( + effective_response_format.schema_spec.schema, "__name__", "response_format" + ) + validation_error = StructuredOutputValidationError(schema_name, exc, output) + raise validation_error + else: + return {"messages": [output], "structured_response": structured_response} return {"messages": [output]} # Handle structured output with tool strategy @@ -812,11 +821,11 @@ def create_agent( # noqa: PLR0915 ] if structured_tool_calls: - exception: Exception | None = None + exception: StructuredOutputError | None = None if len(structured_tool_calls) > 1: # Handle multiple structured outputs error tool_names = [tc["name"] for tc in structured_tool_calls] - exception = MultipleStructuredOutputsError(tool_names) + exception = MultipleStructuredOutputsError(tool_names, output) should_retry, error_message = _handle_structured_output_error( exception, effective_response_format ) @@ -858,7 +867,7 @@ def create_agent( # noqa: PLR0915 "structured_response": structured_response, } except Exception as exc: # noqa: BLE001 - exception = StructuredOutputValidationError(tool_call["name"], exc) + exception = StructuredOutputValidationError(tool_call["name"], exc, output) should_retry, error_message = _handle_structured_output_error( exception, effective_response_format ) diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py index cd6a2fd9aed..75038675807 100644 --- a/libs/langchain_v1/langchain/agents/structured_output.py +++ b/libs/langchain_v1/langchain/agents/structured_output.py @@ -34,17 +34,21 @@ SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"] class StructuredOutputError(Exception): """Base class for structured output errors.""" + ai_message: AIMessage + class MultipleStructuredOutputsError(StructuredOutputError): """Raised when model returns multiple structured output tool calls when only one is expected.""" - def __init__(self, tool_names: list[str]) -> None: + def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None: """Initialize `MultipleStructuredOutputsError`. Args: tool_names: The names of the tools called for structured output. + ai_message: The AI message that contained the invalid multiple tool calls. """ self.tool_names = tool_names + self.ai_message = ai_message super().__init__( "Model incorrectly returned multiple structured responses " @@ -55,15 +59,17 @@ class MultipleStructuredOutputsError(StructuredOutputError): class StructuredOutputValidationError(StructuredOutputError): """Raised when structured output tool call arguments fail to parse according to the schema.""" - def __init__(self, tool_name: str, source: Exception) -> None: + def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None: """Initialize `StructuredOutputValidationError`. Args: tool_name: The name of the tool that failed. source: The exception that occurred. + ai_message: The AI message that contained the invalid structured output. """ self.tool_name = tool_name self.source = source + self.ai_message = ai_message super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.") diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py new file mode 100644 index 00000000000..a04f670ad4a --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/test_structured_output_retry.py @@ -0,0 +1,369 @@ +"""Tests for StructuredOutputRetryMiddleware functionality.""" + +from collections.abc import Callable + +import pytest +from langchain_core.messages import HumanMessage +from langchain_core.tools import tool +from langgraph.checkpoint.memory import InMemorySaver +from pydantic import BaseModel + +from langchain.agents import create_agent +from langchain.agents.middleware.types import ( + AgentMiddleware, + ModelRequest, + ModelResponse, +) +from langchain.agents.structured_output import StructuredOutputError, ToolStrategy +from tests.unit_tests.agents.model import FakeToolCallingModel + + +class StructuredOutputRetryMiddleware(AgentMiddleware): + """Retries model calls when structured output parsing fails.""" + + def __init__(self, max_retries: int) -> None: + """Initialize the structured output retry middleware. + + Args: + max_retries: Maximum number of retry attempts. + """ + self.max_retries = max_retries + + def wrap_model_call( + self, request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] + ) -> ModelResponse: + """Intercept and control model execution via handler callback. + + Args: + request: The model request containing messages and configuration. + handler: The function to call the model. + + Returns: + The model response. + + Raises: + StructuredOutputError: If max retries exceeded without success. + """ + for attempt in range(self.max_retries + 1): + try: + return handler(request) + except StructuredOutputError as exc: + if attempt == self.max_retries: + raise + + # Include both the AI message and error in a single human message + # to maintain valid chat history alternation + ai_content = exc.ai_message.content + error_message = ( + f"Your previous response was:\n{ai_content}\n\n" + f"Error: {exc}. Please try again with a valid response." + ) + request.messages.append(HumanMessage(content=error_message)) + + # This should never be reached, but satisfies type checker + return handler(request) + + +class WeatherReport(BaseModel): + """Weather report schema for testing.""" + + temperature: float + conditions: str + + +@tool +def get_weather(city: str) -> str: + """Get the weather for a given city. + + Args: + city: The city to get weather for. + + Returns: + Weather information for the city. + """ + return f"The weather in {city} is sunny and 72 degrees." + + +def test_structured_output_retry_first_attempt_invalid() -> None: + """Test structured output retry when first two attempts have invalid output.""" + # First two attempts have invalid tool arguments, third attempt succeeds + # The model will call the WeatherReport structured output tool + tool_calls = [ + # First attempt - invalid: wrong type for temperature + [ + { + "name": "WeatherReport", + "id": "1", + "args": {"temperature": "not-a-float", "conditions": "sunny"}, + } + ], + # Second attempt - invalid: missing required field + [{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}], + # Third attempt - valid + [ + { + "name": "WeatherReport", + "id": "3", + "args": {"temperature": 72.5, "conditions": "sunny"}, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + retry_middleware = StructuredOutputRetryMiddleware(max_retries=2) + + agent = create_agent( + model=model, + tools=[get_weather], + middleware=[retry_middleware], + response_format=ToolStrategy(schema=WeatherReport, handle_errors=False), + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("What's the weather in Tokyo?")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify we got a structured response + assert "structured_response" in result + structured = result["structured_response"] + assert isinstance(structured, WeatherReport) + assert structured.temperature == 72.5 + assert structured.conditions == "sunny" + + # Verify the model was called 3 times (initial + 2 retries) + assert model.index == 3 + + +def test_structured_output_retry_exceeds_max_retries() -> None: + """Test structured output retry raises error when max retries exceeded.""" + # All three attempts return invalid arguments + tool_calls = [ + [ + { + "name": "WeatherReport", + "id": "1", + "args": {"temperature": "invalid", "conditions": "sunny"}, + } + ], + [ + { + "name": "WeatherReport", + "id": "2", + "args": {"temperature": "also-invalid", "conditions": "cloudy"}, + } + ], + [ + { + "name": "WeatherReport", + "id": "3", + "args": {"temperature": "still-invalid", "conditions": "rainy"}, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + retry_middleware = StructuredOutputRetryMiddleware(max_retries=2) + + agent = create_agent( + model=model, + tools=[get_weather], + middleware=[retry_middleware], + response_format=ToolStrategy(schema=WeatherReport, handle_errors=False), + # No checkpointer - we expect this to fail + ) + + # Should raise StructuredOutputError after exhausting retries + with pytest.raises(StructuredOutputError): + agent.invoke( + {"messages": [HumanMessage("What's the weather in Tokyo?")]}, + ) + + # Verify the model was called 3 times (initial + 2 retries) + assert model.index == 3 + + +def test_structured_output_retry_succeeds_first_attempt() -> None: + """Test structured output retry when first attempt succeeds (no retry needed).""" + # First attempt returns valid structured output + tool_calls = [ + [ + { + "name": "WeatherReport", + "id": "1", + "args": {"temperature": 68.0, "conditions": "cloudy"}, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + retry_middleware = StructuredOutputRetryMiddleware(max_retries=2) + + agent = create_agent( + model=model, + tools=[get_weather], + middleware=[retry_middleware], + response_format=ToolStrategy(schema=WeatherReport, handle_errors=False), + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("What's the weather in Paris?")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify we got a structured response + assert "structured_response" in result + structured = result["structured_response"] + assert isinstance(structured, WeatherReport) + assert structured.temperature == 68.0 + assert structured.conditions == "cloudy" + + # Verify the model was called only once + assert model.index == 1 + + +def test_structured_output_retry_validation_error() -> None: + """Test structured output retry with schema validation errors.""" + # First attempt has wrong type, second has missing field, third succeeds + tool_calls = [ + [ + { + "name": "WeatherReport", + "id": "1", + "args": {"temperature": "seventy-two", "conditions": "sunny"}, + } + ], + [{"name": "WeatherReport", "id": "2", "args": {"temperature": 72.5}}], + [ + { + "name": "WeatherReport", + "id": "3", + "args": {"temperature": 72.5, "conditions": "partly cloudy"}, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + retry_middleware = StructuredOutputRetryMiddleware(max_retries=2) + + agent = create_agent( + model=model, + tools=[get_weather], + middleware=[retry_middleware], + response_format=ToolStrategy(schema=WeatherReport, handle_errors=False), + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("What's the weather in London?")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify we got a structured response + assert "structured_response" in result + structured = result["structured_response"] + assert isinstance(structured, WeatherReport) + assert structured.temperature == 72.5 + assert structured.conditions == "partly cloudy" + + # Verify the model was called 3 times + assert model.index == 3 + + +def test_structured_output_retry_zero_retries() -> None: + """Test structured output retry with max_retries=0 (no retries allowed).""" + # First attempt returns invalid arguments + tool_calls = [ + [ + { + "name": "WeatherReport", + "id": "1", + "args": {"temperature": "invalid", "conditions": "sunny"}, + } + ], + [ + { + "name": "WeatherReport", + "id": "2", + "args": {"temperature": 72.5, "conditions": "sunny"}, + } + ], # Would succeed if retried + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + retry_middleware = StructuredOutputRetryMiddleware(max_retries=0) + + agent = create_agent( + model=model, + tools=[get_weather], + middleware=[retry_middleware], + response_format=ToolStrategy(schema=WeatherReport, handle_errors=False), + checkpointer=InMemorySaver(), + ) + + # Should fail immediately without retrying + with pytest.raises(StructuredOutputError): + agent.invoke( + {"messages": [HumanMessage("What's the weather in Berlin?")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify the model was called only once (no retries) + assert model.index == 1 + + +def test_structured_output_retry_preserves_messages() -> None: + """Test structured output retry preserves error feedback in messages.""" + # First attempt invalid, second succeeds + tool_calls = [ + [ + { + "name": "WeatherReport", + "id": "1", + "args": {"temperature": "invalid", "conditions": "rainy"}, + } + ], + [ + { + "name": "WeatherReport", + "id": "2", + "args": {"temperature": 75.0, "conditions": "rainy"}, + } + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + retry_middleware = StructuredOutputRetryMiddleware(max_retries=1) + + agent = create_agent( + model=model, + tools=[get_weather], + middleware=[retry_middleware], + response_format=ToolStrategy(schema=WeatherReport, handle_errors=False), + checkpointer=InMemorySaver(), + ) + + result = agent.invoke( + {"messages": [HumanMessage("What's the weather in Seattle?")]}, + {"configurable": {"thread_id": "test"}}, + ) + + # Verify structured response is correct + assert "structured_response" in result + structured = result["structured_response"] + assert structured.temperature == 75.0 + assert structured.conditions == "rainy" + + # Verify messages include the retry feedback + messages = result["messages"] + human_messages = [m for m in messages if isinstance(m, HumanMessage)] + + # Should have at least 2 human messages: initial + retry feedback + assert len(human_messages) >= 2 + + # The retry feedback message should contain error information + retry_message = human_messages[-1] + assert "Error:" in retry_message.content + assert "Please try again" in retry_message.content diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py index a7963ced16f..7df5c23463b 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py @@ -610,6 +610,35 @@ class TestResponseFormatAsToolStrategy: ) assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + def test_validation_error_with_invalid_response(self) -> None: + """Test that StructuredOutputValidationError is raised when tool strategy receives invalid response.""" + tool_calls = [ + [ + { + "name": "WeatherBaseModel", + "id": "1", + "args": {"invalid_field": "wrong_data", "another_bad_field": 123}, + }, + ], + ] + + model = FakeToolCallingModel(tool_calls=tool_calls) + + agent = create_agent( + model, + [], + response_format=ToolStrategy( + WeatherBaseModel, + handle_errors=False, # Disable retry to ensure error is raised + ), + ) + + with pytest.raises( + StructuredOutputValidationError, + match=".*WeatherBaseModel.*", + ): + agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + class TestResponseFormatAsProviderStrategy: def test_pydantic_model(self) -> None: @@ -630,6 +659,28 @@ class TestResponseFormatAsProviderStrategy: assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC assert len(response["messages"]) == 4 + def test_validation_error_with_invalid_response(self) -> None: + """Test that StructuredOutputValidationError is raised when provider strategy receives invalid response.""" + tool_calls = [ + [{"args": {}, "id": "1", "name": "get_weather"}], + ] + + # But we're using WeatherBaseModel which has different field requirements + model = FakeToolCallingModel[dict]( + tool_calls=tool_calls, + structured_response={"invalid": "data"}, # Wrong structure + ) + + agent = create_agent( + model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel) + ) + + with pytest.raises( + StructuredOutputValidationError, + match=".*WeatherBaseModel.*", + ): + agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + def test_dataclass(self) -> None: """Test response_format as ProviderStrategy with dataclass.""" tool_calls = [