From 905c6d7bad48a787bbceaa86d1ab4a2808a36dc4 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Sat, 4 Oct 2025 11:56:56 -0400 Subject: [PATCH] fix(langchain_v1): handle switching response format strategy based on model identity (#33259) Change response format strategy dynamically based on model. After this PR there are two remaining issues: - [ ] Review binding of tools used for output to ToolNode (shouldn't be required) - [ ] Update ModelRequest to also support the original schema provided by the user (to correctly support auto mode) --- .../langchain/agents/middleware_agent.py | 168 ++++++++++++------ .../unit_tests/agents/test_response_format.py | 93 +++++++++- 2 files changed, 201 insertions(+), 60 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 17c2825d0e7..ac2a0042a85 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -120,8 +120,15 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l return [] -def _supports_native_structured_output(model: str | BaseChatModel) -> bool: - """Check if a model supports native structured output.""" +def _supports_provider_strategy(model: str | BaseChatModel) -> bool: + """Check if a model supports provider-specific structured output. + + Args: + model: Model name string or BaseChatModel instance. + + Returns: + ``True`` if the model supports provider-specific structured output, ``False`` otherwise. + """ model_name: str | None = None if isinstance(model, str): model_name = model @@ -186,28 +193,25 @@ def create_agent( # noqa: PLR0915 if tools is None: tools = [] - # Setup structured output + # Convert response format and setup structured output tools + # Raw schemas are converted to ToolStrategy upfront to calculate tools during agent creation. 
+ # If auto-detection is needed, the strategy may be replaced with ProviderStrategy later. + initial_response_format: ToolStrategy | ProviderStrategy | None + is_auto_detect: bool + if response_format is None: + initial_response_format, is_auto_detect = None, False + elif isinstance(response_format, (ToolStrategy, ProviderStrategy)): + # Preserve explicitly requested strategies + initial_response_format, is_auto_detect = response_format, False + else: + # Raw schema - convert to ToolStrategy for now (may be replaced with ProviderStrategy) + initial_response_format, is_auto_detect = ToolStrategy(schema=response_format), True + structured_output_tools: dict[str, OutputToolBinding] = {} - native_output_binding: ProviderStrategyBinding | None = None - - if response_format is not None: - if not isinstance(response_format, (ToolStrategy, ProviderStrategy)): - # Auto-detect strategy based on model capabilities - if _supports_native_structured_output(model): - response_format = ProviderStrategy(schema=response_format) - else: - response_format = ToolStrategy(schema=response_format) - - if isinstance(response_format, ToolStrategy): - # Setup tools strategy for structured output - for response_schema in response_format.schema_specs: - structured_tool_info = OutputToolBinding.from_schema_spec(response_schema) - structured_output_tools[structured_tool_info.tool.name] = structured_tool_info - elif isinstance(response_format, ProviderStrategy): - # Setup native strategy - native_output_binding = ProviderStrategyBinding.from_schema_spec( - response_format.schema_spec - ) + if isinstance(initial_response_format, ToolStrategy): + for response_schema in initial_response_format.schema_specs: + structured_tool_info = OutputToolBinding.from_schema_spec(response_schema) + structured_output_tools[structured_tool_info.tool.name] = structured_tool_info middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] # Setup tools @@ -280,18 +284,29 @@ def create_agent( # noqa: 
PLR0915 context_schema=context_schema, ) - def _handle_model_output(output: AIMessage) -> dict[str, Any]: - """Handle model output including structured responses.""" - # Handle structured output with native strategy - if isinstance(response_format, ProviderStrategy): - if not output.tool_calls and native_output_binding: - structured_response = native_output_binding.parse(output) + def _handle_model_output( + output: AIMessage, effective_response_format: ResponseFormat | None + ) -> dict[str, Any]: + """Handle model output including structured responses. + + Args: + output: The AI message output from the model. + effective_response_format: The actual strategy used + (may differ from initial if auto-detected). + """ + # Handle structured output with provider strategy + if isinstance(effective_response_format, ProviderStrategy): + if not output.tool_calls: + provider_strategy_binding = ProviderStrategyBinding.from_schema_spec( + effective_response_format.schema_spec + ) + structured_response = provider_strategy_binding.parse(output) return {"messages": [output], "structured_response": structured_response} return {"messages": [output]} - # Handle structured output with tools strategy + # Handle structured output with tool strategy if ( - isinstance(response_format, ToolStrategy) + isinstance(effective_response_format, ToolStrategy) and isinstance(output, AIMessage) and output.tool_calls ): @@ -306,7 +321,7 @@ def create_agent( # noqa: PLR0915 tool_names = [tc["name"] for tc in structured_tool_calls] exception = MultipleStructuredOutputsError(tool_names) should_retry, error_message = _handle_structured_output_error( - exception, response_format + exception, effective_response_format ) if not should_retry: raise exception @@ -329,8 +344,8 @@ def create_agent( # noqa: PLR0915 structured_response = structured_tool_binding.parse(tool_call["args"]) tool_message_content = ( - response_format.tool_message_content - if response_format.tool_message_content + 
effective_response_format.tool_message_content + if effective_response_format.tool_message_content else f"Returning structured response: {structured_response}" ) @@ -348,7 +363,7 @@ def create_agent( # noqa: PLR0915 except Exception as exc: # noqa: BLE001 exception = StructuredOutputValidationError(tool_call["name"], exc) should_retry, error_message = _handle_structured_output_error( - exception, response_format + exception, effective_response_format ) if not should_retry: raise exception @@ -366,11 +381,20 @@ def create_agent( # noqa: PLR0915 return {"messages": [output]} - def _get_bound_model(request: ModelRequest) -> Runnable: - """Get the model with appropriate tool bindings.""" - # Get actual tool objects from tool names - tools_by_name = {t.name: t for t in default_tools} + def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]: + """Get the model with appropriate tool bindings. + Performs auto-detection of strategy if needed based on model capabilities. + + Args: + request: The model request containing model, tools, and response format. + + Returns: + Tuple of (bound_model, effective_response_format) where ``effective_response_format`` + is the actual strategy used (may differ from initial if auto-detected). 
+ """ + # Validate requested tools are available + tools_by_name = {t.name: t for t in default_tools} unknown_tools = [name for name in request.tools if name not in tools_by_name] if unknown_tools: available_tools = sorted(tools_by_name.keys()) @@ -389,23 +413,49 @@ def create_agent( # noqa: PLR0915 requested_tools = [tools_by_name[name] for name in request.tools] - if isinstance(response_format, ProviderStrategy): - # Use native structured output - kwargs = response_format.to_model_kwargs() - return request.model.bind_tools( - requested_tools, strict=True, **kwargs, **request.model_settings + # Determine effective response format (auto-detect if needed) + effective_response_format: ResponseFormat | None = request.response_format + if ( + # User provided raw schema - auto-detect best strategy based on model + is_auto_detect + and isinstance(request.response_format, ToolStrategy) + and + # Model supports provider strategy - use it instead + _supports_provider_strategy(request.model) + ): + effective_response_format = ProviderStrategy(schema=response_format) # type: ignore[arg-type] + # else: keep ToolStrategy from initial conversion + + # Bind model based on effective response format + if isinstance(effective_response_format, ProviderStrategy): + # Use provider-specific structured output + kwargs = effective_response_format.to_model_kwargs() + return ( + request.model.bind_tools( + requested_tools, strict=True, **kwargs, **request.model_settings + ), + effective_response_format, ) - if isinstance(response_format, ToolStrategy): + + if isinstance(effective_response_format, ToolStrategy): + # Force tool use if we have structured output tools tool_choice = "any" if structured_output_tools else request.tool_choice - return request.model.bind_tools( - requested_tools, tool_choice=tool_choice, **request.model_settings + return ( + request.model.bind_tools( + requested_tools, tool_choice=tool_choice, **request.model_settings + ), + effective_response_format, ) - # Standard 
model binding + + # No structured output - standard model binding if requested_tools: - return request.model.bind_tools( - requested_tools, tool_choice=request.tool_choice, **request.model_settings + return ( + request.model.bind_tools( + requested_tools, tool_choice=request.tool_choice, **request.model_settings + ), + None, ) - return request.model.bind(**request.model_settings) + return request.model.bind(**request.model_settings), None def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: """Sync model request handler with sequential middleware processing.""" @@ -413,7 +463,7 @@ def create_agent( # noqa: PLR0915 model=model, tools=[t.name for t in default_tools], system_prompt=system_prompt, - response_format=response_format, + response_format=initial_response_format, messages=state["messages"], tool_choice=None, ) @@ -431,8 +481,8 @@ def create_agent( # noqa: PLR0915 ) raise TypeError(msg) - # Get the final model and messages - model_ = _get_bound_model(request) + # Get the bound model (with auto-detection if needed) + model_, effective_response_format = _get_bound_model(request) messages = request.messages if request.system_prompt: messages = [SystemMessage(request.system_prompt), *messages] @@ -441,7 +491,7 @@ def create_agent( # noqa: PLR0915 return { "thread_model_call_count": state.get("thread_model_call_count", 0) + 1, "run_model_call_count": state.get("run_model_call_count", 0) + 1, - **_handle_model_output(output), + **_handle_model_output(output, effective_response_format), } async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]: @@ -450,7 +500,7 @@ def create_agent( # noqa: PLR0915 model=model, tools=[t.name for t in default_tools], system_prompt=system_prompt, - response_format=response_format, + response_format=initial_response_format, messages=state["messages"], tool_choice=None, ) @@ -459,8 +509,8 @@ def create_agent( # noqa: PLR0915 for m in middleware_w_modify_model_request: await 
m.amodify_model_request(request, state, runtime) - # Get the final model and messages - model_ = _get_bound_model(request) + # Get the bound model (with auto-detection if needed) + model_, effective_response_format = _get_bound_model(request) messages = request.messages if request.system_prompt: messages = [SystemMessage(request.system_prompt), *messages] @@ -469,7 +519,7 @@ def create_agent( # noqa: PLR0915 return { "thread_model_call_count": state.get("thread_model_call_count", 0) + 1, "run_model_call_count": state.get("run_model_call_count", 0) + 1, - **_handle_model_output(output), + **_handle_model_output(output, effective_response_format), } # Use sync or async based on model capabilities diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py index 8b1c85a0566..8d2affb266e 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py @@ -1,9 +1,11 @@ """Test suite for create_agent with structured output response_format permutations.""" +import json + import pytest from dataclasses import dataclass -from typing import Union +from typing import Union, Sequence, Any, Callable from langchain_core.messages import HumanMessage from langchain.agents import create_agent @@ -13,10 +15,16 @@ from langchain.agents.structured_output import ( StructuredOutputValidationError, ToolStrategy, ) +from langchain.tools import tool from pydantic import BaseModel, Field from typing_extensions import TypedDict +from langchain.messages import AIMessage +from langchain_core.messages import BaseMessage +from langchain_core.language_models import LanguageModelInput +from langchain_core.runnables import Runnable from tests.unit_tests.agents.model import FakeToolCallingModel +from langchain.tools import BaseTool # Test data models @@ -676,6 +684,89 @@ class TestResponseFormatAsProviderStrategy: assert 
len(response["messages"]) == 4 +class TestDynamicModelWithResponseFormat: + """Test response_format with middleware that modifies the model.""" + + def test_middleware_model_swap_provider_to_tool_strategy(self) -> None: + """Test that strategy resolution is deferred until after middleware modifies the model. + + Verifies that when a raw schema is provided, ``_supports_provider_strategy`` is called + on the middleware-modified model (not the original), ensuring the correct strategy is + selected based on the final model's capabilities. + """ + from unittest.mock import patch + from langchain.agents.middleware.types import AgentMiddleware, ModelRequest + from langchain_core.language_models.fake_chat_models import GenericFakeChatModel + + # Custom model that we'll use to test whether the tool strategy is applied + # correctly at runtime. + class CustomModel(GenericFakeChatModel): + tool_bindings: list[Any] = [] + + def bind_tools( + self, + tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + # Record every tool binding event. 
+ self.tool_bindings.append(tools) + return self + + model = CustomModel( + messages=iter( + [ + # Simulate model returning structured output directly + # (this is what provider strategy would do) + json.dumps(WEATHER_DATA), + ] + ) + ) + + # Create middleware that swaps the model in the request + class ModelSwappingMiddleware(AgentMiddleware): + def modify_model_request(self, request: ModelRequest, state, runtime) -> ModelRequest: + # Replace the model with our custom test model + request.model = model + return request + + # Track which model is checked for provider strategy support + calls = [] + + def mock_supports_provider_strategy(model) -> bool: + """Track which model is checked and return True for ProviderStrategy.""" + calls.append(model) + return True + + # Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy) + # This should auto-detect strategy based on model capabilities + agent = create_agent( + model=model, + tools=[], + # Raw schema - should auto-detect strategy + response_format=WeatherBaseModel, + middleware=[ModelSwappingMiddleware()], + ) + + with patch( + "langchain.agents.middleware_agent._supports_provider_strategy", + side_effect=mock_supports_provider_strategy, + ): + response = agent.invoke({"messages": [HumanMessage("What's the weather?")]}) + + # Verify strategy resolution was deferred: check was called once during _get_bound_model + assert len(calls) == 1 + + # Verify successful parsing of JSON as structured output via ProviderStrategy + assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC + # Two messages: Human input message and AI response with JSON content + assert len(response["messages"]) == 2 + ai_message = response["messages"][1] + assert isinstance(ai_message, AIMessage) + # ProviderStrategy doesn't use tool calls - it parses content directly + assert ai_message.tool_calls == [] + assert ai_message.content == json.dumps(WEATHER_DATA) + + def test_union_of_types() -> None: """Test response_format 
as ProviderStrategy with Union (if supported).""" tool_calls = [