fix(langchain_v1): handle switching response format strategy based on model identity (#33259)

Change response format strategy dynamically based on model.

After this PR there are two remaining issues:

- [ ] Review binding of tools used for output to ToolNode (shouldn't be
required)
- [ ] Update ModelRequest to also support the original schema provided
by the user (to correctly support auto mode)
This commit is contained in:
Eugene Yurtsev
2025-10-04 11:56:56 -04:00
committed by GitHub
parent acd1aa813c
commit 905c6d7bad
2 changed files with 201 additions and 60 deletions

View File

@@ -120,8 +120,15 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
return []
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
"""Check if a model supports native structured output."""
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
"""Check if a model supports provider-specific structured output.
Args:
model: Model name string or BaseChatModel instance.
Returns:
``True`` if the model supports provider-specific structured output, ``False`` otherwise.
"""
model_name: str | None = None
if isinstance(model, str):
model_name = model
@@ -186,28 +193,25 @@ def create_agent( # noqa: PLR0915
if tools is None:
tools = []
# Setup structured output
# Convert response format and setup structured output tools
# Raw schemas are converted to ToolStrategy upfront to calculate tools during agent creation.
# If auto-detection is needed, the strategy may be replaced with ProviderStrategy later.
initial_response_format: ToolStrategy | ProviderStrategy | None
is_auto_detect: bool
if response_format is None:
initial_response_format, is_auto_detect = None, False
elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
# Preserve explicitly requested strategies
initial_response_format, is_auto_detect = response_format, False
else:
# Raw schema - convert to ToolStrategy for now (may be replaced with ProviderStrategy)
initial_response_format, is_auto_detect = ToolStrategy(schema=response_format), True
structured_output_tools: dict[str, OutputToolBinding] = {}
native_output_binding: ProviderStrategyBinding | None = None
if response_format is not None:
if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
# Auto-detect strategy based on model capabilities
if _supports_native_structured_output(model):
response_format = ProviderStrategy(schema=response_format)
else:
response_format = ToolStrategy(schema=response_format)
if isinstance(response_format, ToolStrategy):
# Setup tools strategy for structured output
for response_schema in response_format.schema_specs:
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
elif isinstance(response_format, ProviderStrategy):
# Setup native strategy
native_output_binding = ProviderStrategyBinding.from_schema_spec(
response_format.schema_spec
)
if isinstance(initial_response_format, ToolStrategy):
for response_schema in initial_response_format.schema_specs:
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Setup tools
@@ -280,18 +284,29 @@ def create_agent( # noqa: PLR0915
context_schema=context_schema,
)
def _handle_model_output(output: AIMessage) -> dict[str, Any]:
"""Handle model output including structured responses."""
# Handle structured output with native strategy
if isinstance(response_format, ProviderStrategy):
if not output.tool_calls and native_output_binding:
structured_response = native_output_binding.parse(output)
def _handle_model_output(
output: AIMessage, effective_response_format: ResponseFormat | None
) -> dict[str, Any]:
"""Handle model output including structured responses.
Args:
output: The AI message output from the model.
effective_response_format: The actual strategy used
(may differ from initial if auto-detected).
"""
# Handle structured output with provider strategy
if isinstance(effective_response_format, ProviderStrategy):
if not output.tool_calls:
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
effective_response_format.schema_spec
)
structured_response = provider_strategy_binding.parse(output)
return {"messages": [output], "structured_response": structured_response}
return {"messages": [output]}
# Handle structured output with tools strategy
# Handle structured output with tool strategy
if (
isinstance(response_format, ToolStrategy)
isinstance(effective_response_format, ToolStrategy)
and isinstance(output, AIMessage)
and output.tool_calls
):
@@ -306,7 +321,7 @@ def create_agent( # noqa: PLR0915
tool_names = [tc["name"] for tc in structured_tool_calls]
exception = MultipleStructuredOutputsError(tool_names)
should_retry, error_message = _handle_structured_output_error(
exception, response_format
exception, effective_response_format
)
if not should_retry:
raise exception
@@ -329,8 +344,8 @@ def create_agent( # noqa: PLR0915
structured_response = structured_tool_binding.parse(tool_call["args"])
tool_message_content = (
response_format.tool_message_content
if response_format.tool_message_content
effective_response_format.tool_message_content
if effective_response_format.tool_message_content
else f"Returning structured response: {structured_response}"
)
@@ -348,7 +363,7 @@ def create_agent( # noqa: PLR0915
except Exception as exc: # noqa: BLE001
exception = StructuredOutputValidationError(tool_call["name"], exc)
should_retry, error_message = _handle_structured_output_error(
exception, response_format
exception, effective_response_format
)
if not should_retry:
raise exception
@@ -366,11 +381,20 @@ def create_agent( # noqa: PLR0915
return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> Runnable:
"""Get the model with appropriate tool bindings."""
# Get actual tool objects from tool names
tools_by_name = {t.name: t for t in default_tools}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
"""Get the model with appropriate tool bindings.
Performs auto-detection of strategy if needed based on model capabilities.
Args:
request: The model request containing model, tools, and response format.
Returns:
Tuple of (bound_model, effective_response_format) where ``effective_response_format``
is the actual strategy used (may differ from initial if auto-detected).
"""
# Validate requested tools are available
tools_by_name = {t.name: t for t in default_tools}
unknown_tools = [name for name in request.tools if name not in tools_by_name]
if unknown_tools:
available_tools = sorted(tools_by_name.keys())
@@ -389,23 +413,49 @@ def create_agent( # noqa: PLR0915
requested_tools = [tools_by_name[name] for name in request.tools]
if isinstance(response_format, ProviderStrategy):
# Use native structured output
kwargs = response_format.to_model_kwargs()
return request.model.bind_tools(
requested_tools, strict=True, **kwargs, **request.model_settings
# Determine effective response format (auto-detect if needed)
effective_response_format: ResponseFormat | None = request.response_format
if (
# User provided raw schema - auto-detect best strategy based on model
is_auto_detect
and isinstance(request.response_format, ToolStrategy)
and
# Model supports provider strategy - use it instead
_supports_provider_strategy(request.model)
):
effective_response_format = ProviderStrategy(schema=response_format) # type: ignore[arg-type]
# else: keep ToolStrategy from initial conversion
# Bind model based on effective response format
if isinstance(effective_response_format, ProviderStrategy):
# Use provider-specific structured output
kwargs = effective_response_format.to_model_kwargs()
return (
request.model.bind_tools(
requested_tools, strict=True, **kwargs, **request.model_settings
),
effective_response_format,
)
if isinstance(response_format, ToolStrategy):
if isinstance(effective_response_format, ToolStrategy):
# Force tool use if we have structured output tools
tool_choice = "any" if structured_output_tools else request.tool_choice
return request.model.bind_tools(
requested_tools, tool_choice=tool_choice, **request.model_settings
return (
request.model.bind_tools(
requested_tools, tool_choice=tool_choice, **request.model_settings
),
effective_response_format,
)
# Standard model binding
# No structured output - standard model binding
if requested_tools:
return request.model.bind_tools(
requested_tools, tool_choice=request.tool_choice, **request.model_settings
return (
request.model.bind_tools(
requested_tools, tool_choice=request.tool_choice, **request.model_settings
),
None,
)
return request.model.bind(**request.model_settings)
return request.model.bind(**request.model_settings), None
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
@@ -413,7 +463,7 @@ def create_agent( # noqa: PLR0915
model=model,
tools=[t.name for t in default_tools],
system_prompt=system_prompt,
response_format=response_format,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
)
@@ -431,8 +481,8 @@ def create_agent( # noqa: PLR0915
)
raise TypeError(msg)
# Get the final model and messages
model_ = _get_bound_model(request)
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
@@ -441,7 +491,7 @@ def create_agent( # noqa: PLR0915
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output),
**_handle_model_output(output, effective_response_format),
}
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
@@ -450,7 +500,7 @@ def create_agent( # noqa: PLR0915
model=model,
tools=[t.name for t in default_tools],
system_prompt=system_prompt,
response_format=response_format,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
)
@@ -459,8 +509,8 @@ def create_agent( # noqa: PLR0915
for m in middleware_w_modify_model_request:
await m.amodify_model_request(request, state, runtime)
# Get the final model and messages
model_ = _get_bound_model(request)
# Get the bound model (with auto-detection if needed)
model_, effective_response_format = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
@@ -469,7 +519,7 @@ def create_agent( # noqa: PLR0915
return {
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
**_handle_model_output(output),
**_handle_model_output(output, effective_response_format),
}
# Use sync or async based on model capabilities

View File

@@ -1,9 +1,11 @@
"""Test suite for create_agent with structured output response_format permutations."""
import json
import pytest
from dataclasses import dataclass
from typing import Union
from typing import Union, Sequence, Any, Callable
from langchain_core.messages import HumanMessage
from langchain.agents import create_agent
@@ -13,10 +15,16 @@ from langchain.agents.structured_output import (
StructuredOutputValidationError,
ToolStrategy,
)
from langchain.tools import tool
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langchain.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.language_models import LanguageModelInput
from langchain_core.runnables import Runnable
from tests.unit_tests.agents.model import FakeToolCallingModel
from langchain.tools import BaseTool
# Test data models
@@ -676,6 +684,89 @@ class TestResponseFormatAsProviderStrategy:
assert len(response["messages"]) == 4
class TestDynamicModelWithResponseFormat:
"""Test response_format with middleware that modifies the model."""
def test_middleware_model_swap_provider_to_tool_strategy(self) -> None:
"""Test that strategy resolution is deferred until after middleware modifies the model.
Verifies that when a raw schema is provided, ``_supports_provider_strategy`` is called
on the middleware-modified model (not the original), ensuring the correct strategy is
selected based on the final model's capabilities.
"""
from unittest.mock import patch
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
# Custom model that we'll use to test whether the tool strategy is applied
# correctly at runtime.
class CustomModel(GenericFakeChatModel):
tool_bindings: list[Any] = []
def bind_tools(
self,
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
# Record every tool binding event.
self.tool_bindings.append(tools)
return self
model = CustomModel(
messages=iter(
[
# Simulate model returning structured output directly
# (this is what provider strategy would do)
json.dumps(WEATHER_DATA),
]
)
)
# Create middleware that swaps the model in the request
class ModelSwappingMiddleware(AgentMiddleware):
def modify_model_request(self, request: ModelRequest, state, runtime) -> ModelRequest:
# Replace the model with our custom test model
request.model = model
return request
# Track which model is checked for provider strategy support
calls = []
def mock_supports_provider_strategy(model) -> bool:
"""Track which model is checked and return True for ProviderStrategy."""
calls.append(model)
return True
# Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
# This should auto-detect strategy based on model capabilities
agent = create_agent(
model=model,
tools=[],
# Raw schema - should auto-detect strategy
response_format=WeatherBaseModel,
middleware=[ModelSwappingMiddleware()],
)
with patch(
"langchain.agents.middleware_agent._supports_provider_strategy",
side_effect=mock_supports_provider_strategy,
):
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
# Verify strategy resolution was deferred: check was called once during _get_bound_model
assert len(calls) == 1
# Verify successful parsing of JSON as structured output via ProviderStrategy
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
# Two messages: Human input message and AI response with JSON content
assert len(response["messages"]) == 2
ai_message = response["messages"][1]
assert isinstance(ai_message, AIMessage)
# ProviderStrategy doesn't use tool calls - it parses content directly
assert ai_message.tool_calls == []
assert ai_message.content == json.dumps(WEATHER_DATA)
def test_union_of_types() -> None:
"""Test response_format as ProviderStrategy with Union (if supported)."""
tool_calls = [