mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-05 03:48:48 +00:00
fix(langchain_v1): handle switching response format strategy based on model identity (#33259)
Change response format strategy dynamically based on model. After this PR there are two remaining issues: - [ ] Review binding of tools used for output to ToolNode (shouldn't be required) - [ ] Update ModelRequest to also support the original schema provided by the user (to correctly support auto mode)
This commit is contained in:
@@ -120,8 +120,15 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
return []
|
||||
|
||||
|
||||
def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
|
||||
"""Check if a model supports native structured output."""
|
||||
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
"""Check if a model supports provider-specific structured output.
|
||||
|
||||
Args:
|
||||
model: Model name string or BaseChatModel instance.
|
||||
|
||||
Returns:
|
||||
``True`` if the model supports provider-specific structured output, ``False`` otherwise.
|
||||
"""
|
||||
model_name: str | None = None
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
@@ -186,28 +193,25 @@ def create_agent( # noqa: PLR0915
|
||||
if tools is None:
|
||||
tools = []
|
||||
|
||||
# Setup structured output
|
||||
# Convert response format and setup structured output tools
|
||||
# Raw schemas are converted to ToolStrategy upfront to calculate tools during agent creation.
|
||||
# If auto-detection is needed, the strategy may be replaced with ProviderStrategy later.
|
||||
initial_response_format: ToolStrategy | ProviderStrategy | None
|
||||
is_auto_detect: bool
|
||||
if response_format is None:
|
||||
initial_response_format, is_auto_detect = None, False
|
||||
elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
|
||||
# Preserve explicitly requested strategies
|
||||
initial_response_format, is_auto_detect = response_format, False
|
||||
else:
|
||||
# Raw schema - convert to ToolStrategy for now (may be replaced with ProviderStrategy)
|
||||
initial_response_format, is_auto_detect = ToolStrategy(schema=response_format), True
|
||||
|
||||
structured_output_tools: dict[str, OutputToolBinding] = {}
|
||||
native_output_binding: ProviderStrategyBinding | None = None
|
||||
|
||||
if response_format is not None:
|
||||
if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
|
||||
# Auto-detect strategy based on model capabilities
|
||||
if _supports_native_structured_output(model):
|
||||
response_format = ProviderStrategy(schema=response_format)
|
||||
else:
|
||||
response_format = ToolStrategy(schema=response_format)
|
||||
|
||||
if isinstance(response_format, ToolStrategy):
|
||||
# Setup tools strategy for structured output
|
||||
for response_schema in response_format.schema_specs:
|
||||
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
|
||||
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
||||
elif isinstance(response_format, ProviderStrategy):
|
||||
# Setup native strategy
|
||||
native_output_binding = ProviderStrategyBinding.from_schema_spec(
|
||||
response_format.schema_spec
|
||||
)
|
||||
if isinstance(initial_response_format, ToolStrategy):
|
||||
for response_schema in initial_response_format.schema_specs:
|
||||
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
|
||||
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
||||
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
||||
|
||||
# Setup tools
|
||||
@@ -280,18 +284,29 @@ def create_agent( # noqa: PLR0915
|
||||
context_schema=context_schema,
|
||||
)
|
||||
|
||||
def _handle_model_output(output: AIMessage) -> dict[str, Any]:
|
||||
"""Handle model output including structured responses."""
|
||||
# Handle structured output with native strategy
|
||||
if isinstance(response_format, ProviderStrategy):
|
||||
if not output.tool_calls and native_output_binding:
|
||||
structured_response = native_output_binding.parse(output)
|
||||
def _handle_model_output(
|
||||
output: AIMessage, effective_response_format: ResponseFormat | None
|
||||
) -> dict[str, Any]:
|
||||
"""Handle model output including structured responses.
|
||||
|
||||
Args:
|
||||
output: The AI message output from the model.
|
||||
effective_response_format: The actual strategy used
|
||||
(may differ from initial if auto-detected).
|
||||
"""
|
||||
# Handle structured output with provider strategy
|
||||
if isinstance(effective_response_format, ProviderStrategy):
|
||||
if not output.tool_calls:
|
||||
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
|
||||
effective_response_format.schema_spec
|
||||
)
|
||||
structured_response = provider_strategy_binding.parse(output)
|
||||
return {"messages": [output], "structured_response": structured_response}
|
||||
return {"messages": [output]}
|
||||
|
||||
# Handle structured output with tools strategy
|
||||
# Handle structured output with tool strategy
|
||||
if (
|
||||
isinstance(response_format, ToolStrategy)
|
||||
isinstance(effective_response_format, ToolStrategy)
|
||||
and isinstance(output, AIMessage)
|
||||
and output.tool_calls
|
||||
):
|
||||
@@ -306,7 +321,7 @@ def create_agent( # noqa: PLR0915
|
||||
tool_names = [tc["name"] for tc in structured_tool_calls]
|
||||
exception = MultipleStructuredOutputsError(tool_names)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, response_format
|
||||
exception, effective_response_format
|
||||
)
|
||||
if not should_retry:
|
||||
raise exception
|
||||
@@ -329,8 +344,8 @@ def create_agent( # noqa: PLR0915
|
||||
structured_response = structured_tool_binding.parse(tool_call["args"])
|
||||
|
||||
tool_message_content = (
|
||||
response_format.tool_message_content
|
||||
if response_format.tool_message_content
|
||||
effective_response_format.tool_message_content
|
||||
if effective_response_format.tool_message_content
|
||||
else f"Returning structured response: {structured_response}"
|
||||
)
|
||||
|
||||
@@ -348,7 +363,7 @@ def create_agent( # noqa: PLR0915
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exception = StructuredOutputValidationError(tool_call["name"], exc)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, response_format
|
||||
exception, effective_response_format
|
||||
)
|
||||
if not should_retry:
|
||||
raise exception
|
||||
@@ -366,11 +381,20 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
return {"messages": [output]}
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> Runnable:
|
||||
"""Get the model with appropriate tool bindings."""
|
||||
# Get actual tool objects from tool names
|
||||
tools_by_name = {t.name: t for t in default_tools}
|
||||
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
|
||||
"""Get the model with appropriate tool bindings.
|
||||
|
||||
Performs auto-detection of strategy if needed based on model capabilities.
|
||||
|
||||
Args:
|
||||
request: The model request containing model, tools, and response format.
|
||||
|
||||
Returns:
|
||||
Tuple of (bound_model, effective_response_format) where ``effective_response_format``
|
||||
is the actual strategy used (may differ from initial if auto-detected).
|
||||
"""
|
||||
# Validate requested tools are available
|
||||
tools_by_name = {t.name: t for t in default_tools}
|
||||
unknown_tools = [name for name in request.tools if name not in tools_by_name]
|
||||
if unknown_tools:
|
||||
available_tools = sorted(tools_by_name.keys())
|
||||
@@ -389,23 +413,49 @@ def create_agent( # noqa: PLR0915
|
||||
|
||||
requested_tools = [tools_by_name[name] for name in request.tools]
|
||||
|
||||
if isinstance(response_format, ProviderStrategy):
|
||||
# Use native structured output
|
||||
kwargs = response_format.to_model_kwargs()
|
||||
return request.model.bind_tools(
|
||||
requested_tools, strict=True, **kwargs, **request.model_settings
|
||||
# Determine effective response format (auto-detect if needed)
|
||||
effective_response_format: ResponseFormat | None = request.response_format
|
||||
if (
|
||||
# User provided raw schema - auto-detect best strategy based on model
|
||||
is_auto_detect
|
||||
and isinstance(request.response_format, ToolStrategy)
|
||||
and
|
||||
# Model supports provider strategy - use it instead
|
||||
_supports_provider_strategy(request.model)
|
||||
):
|
||||
effective_response_format = ProviderStrategy(schema=response_format) # type: ignore[arg-type]
|
||||
# else: keep ToolStrategy from initial conversion
|
||||
|
||||
# Bind model based on effective response format
|
||||
if isinstance(effective_response_format, ProviderStrategy):
|
||||
# Use provider-specific structured output
|
||||
kwargs = effective_response_format.to_model_kwargs()
|
||||
return (
|
||||
request.model.bind_tools(
|
||||
requested_tools, strict=True, **kwargs, **request.model_settings
|
||||
),
|
||||
effective_response_format,
|
||||
)
|
||||
if isinstance(response_format, ToolStrategy):
|
||||
|
||||
if isinstance(effective_response_format, ToolStrategy):
|
||||
# Force tool use if we have structured output tools
|
||||
tool_choice = "any" if structured_output_tools else request.tool_choice
|
||||
return request.model.bind_tools(
|
||||
requested_tools, tool_choice=tool_choice, **request.model_settings
|
||||
return (
|
||||
request.model.bind_tools(
|
||||
requested_tools, tool_choice=tool_choice, **request.model_settings
|
||||
),
|
||||
effective_response_format,
|
||||
)
|
||||
# Standard model binding
|
||||
|
||||
# No structured output - standard model binding
|
||||
if requested_tools:
|
||||
return request.model.bind_tools(
|
||||
requested_tools, tool_choice=request.tool_choice, **request.model_settings
|
||||
return (
|
||||
request.model.bind_tools(
|
||||
requested_tools, tool_choice=request.tool_choice, **request.model_settings
|
||||
),
|
||||
None,
|
||||
)
|
||||
return request.model.bind(**request.model_settings)
|
||||
return request.model.bind(**request.model_settings), None
|
||||
|
||||
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
@@ -413,7 +463,7 @@ def create_agent( # noqa: PLR0915
|
||||
model=model,
|
||||
tools=[t.name for t in default_tools],
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
@@ -431,8 +481,8 @@ def create_agent( # noqa: PLR0915
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
# Get the final model and messages
|
||||
model_ = _get_bound_model(request)
|
||||
# Get the bound model (with auto-detection if needed)
|
||||
model_, effective_response_format = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
@@ -441,7 +491,7 @@ def create_agent( # noqa: PLR0915
|
||||
return {
|
||||
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
||||
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
||||
**_handle_model_output(output),
|
||||
**_handle_model_output(output, effective_response_format),
|
||||
}
|
||||
|
||||
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
@@ -450,7 +500,7 @@ def create_agent( # noqa: PLR0915
|
||||
model=model,
|
||||
tools=[t.name for t in default_tools],
|
||||
system_prompt=system_prompt,
|
||||
response_format=response_format,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
@@ -459,8 +509,8 @@ def create_agent( # noqa: PLR0915
|
||||
for m in middleware_w_modify_model_request:
|
||||
await m.amodify_model_request(request, state, runtime)
|
||||
|
||||
# Get the final model and messages
|
||||
model_ = _get_bound_model(request)
|
||||
# Get the bound model (with auto-detection if needed)
|
||||
model_, effective_response_format = _get_bound_model(request)
|
||||
messages = request.messages
|
||||
if request.system_prompt:
|
||||
messages = [SystemMessage(request.system_prompt), *messages]
|
||||
@@ -469,7 +519,7 @@ def create_agent( # noqa: PLR0915
|
||||
return {
|
||||
"thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
|
||||
"run_model_call_count": state.get("run_model_call_count", 0) + 1,
|
||||
**_handle_model_output(output),
|
||||
**_handle_model_output(output, effective_response_format),
|
||||
}
|
||||
|
||||
# Use sync or async based on model capabilities
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Test suite for create_agent with structured output response_format permutations."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Union, Sequence, Any, Callable
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain.agents import create_agent
|
||||
@@ -13,10 +15,16 @@ from langchain.agents.structured_output import (
|
||||
StructuredOutputValidationError,
|
||||
ToolStrategy,
|
||||
)
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.runnables import Runnable
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
# Test data models
|
||||
@@ -676,6 +684,89 @@ class TestResponseFormatAsProviderStrategy:
|
||||
assert len(response["messages"]) == 4
|
||||
|
||||
|
||||
class TestDynamicModelWithResponseFormat:
|
||||
"""Test response_format with middleware that modifies the model."""
|
||||
|
||||
def test_middleware_model_swap_provider_to_tool_strategy(self) -> None:
|
||||
"""Test that strategy resolution is deferred until after middleware modifies the model.
|
||||
|
||||
Verifies that when a raw schema is provided, ``_supports_provider_strategy`` is called
|
||||
on the middleware-modified model (not the original), ensuring the correct strategy is
|
||||
selected based on the final model's capabilities.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
|
||||
# Custom model that we'll use to test whether the tool strategy is applied
|
||||
# correctly at runtime.
|
||||
class CustomModel(GenericFakeChatModel):
|
||||
tool_bindings: list[Any] = []
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
# Record every tool binding event.
|
||||
self.tool_bindings.append(tools)
|
||||
return self
|
||||
|
||||
model = CustomModel(
|
||||
messages=iter(
|
||||
[
|
||||
# Simulate model returning structured output directly
|
||||
# (this is what provider strategy would do)
|
||||
json.dumps(WEATHER_DATA),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Create middleware that swaps the model in the request
|
||||
class ModelSwappingMiddleware(AgentMiddleware):
|
||||
def modify_model_request(self, request: ModelRequest, state, runtime) -> ModelRequest:
|
||||
# Replace the model with our custom test model
|
||||
request.model = model
|
||||
return request
|
||||
|
||||
# Track which model is checked for provider strategy support
|
||||
calls = []
|
||||
|
||||
def mock_supports_provider_strategy(model) -> bool:
|
||||
"""Track which model is checked and return True for ProviderStrategy."""
|
||||
calls.append(model)
|
||||
return True
|
||||
|
||||
# Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
|
||||
# This should auto-detect strategy based on model capabilities
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[],
|
||||
# Raw schema - should auto-detect strategy
|
||||
response_format=WeatherBaseModel,
|
||||
middleware=[ModelSwappingMiddleware()],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"langchain.agents.middleware_agent._supports_provider_strategy",
|
||||
side_effect=mock_supports_provider_strategy,
|
||||
):
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
# Verify strategy resolution was deferred: check was called once during _get_bound_model
|
||||
assert len(calls) == 1
|
||||
|
||||
# Verify successful parsing of JSON as structured output via ProviderStrategy
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
# Two messages: Human input message and AI response with JSON content
|
||||
assert len(response["messages"]) == 2
|
||||
ai_message = response["messages"][1]
|
||||
assert isinstance(ai_message, AIMessage)
|
||||
# ProviderStrategy doesn't use tool calls - it parses content directly
|
||||
assert ai_message.tool_calls == []
|
||||
assert ai_message.content == json.dumps(WEATHER_DATA)
|
||||
|
||||
|
||||
def test_union_of_types() -> None:
|
||||
"""Test response_format as ProviderStrategy with Union (if supported)."""
|
||||
tool_calls = [
|
||||
|
||||
Reference in New Issue
Block a user