Compare commits

...

11 Commits

Author  SHA1  Message  Date
Sydney Runkle  c16969d8ac  typo  2025-10-16 15:15:51 -04:00
Sydney Runkle  038e78fa95  Merge branch 'strategy-binding' of https://github.com/langchain-ai/langchain into strategy-binding  2025-10-16 14:24:58 -04:00
Sydney Runkle  e18a21a8b0  tests'  2025-10-16 14:24:54 -04:00
Sydney Runkle  c8b1162067  Apply suggestions from code review  2025-10-16 13:40:11 -04:00
Sydney Runkle  d8d94e6c95  Apply suggestions from code review  2025-10-16 13:38:45 -04:00
Sydney Runkle  e4ab6896a7  another pass  2025-10-16 13:33:46 -04:00
Sydney Runkle  6efea75d1c  linting  2025-10-16 13:24:02 -04:00
Sydney Runkle  5155fbe072  better structure  2025-10-16 13:20:47 -04:00
Sydney Runkle  7157c2f69c  openai and xai  2025-10-16 13:12:14 -04:00
Sydney Runkle  c97e649bd3  linting, ofc  2025-10-16 12:44:44 -04:00
Sydney Runkle  602284dc3e  provider specific kwargs  2025-10-16 12:14:48 -04:00
4 changed files with 193 additions and 60 deletions

View File

@@ -13,7 +13,6 @@ from typing import (
     get_type_hints,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph._internal._runnable import RunnableCallable
@@ -42,6 +41,7 @@ from langchain.agents.structured_output import (
     ResponseFormat,
     StructuredOutputValidationError,
     ToolStrategy,
+    _supports_provider_strategy,
 )
 from langchain.chat_models import init_chat_model
 from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
@@ -49,6 +49,7 @@ from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Sequence
+    from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
     from langgraph.graph.state import CompiledStateGraph
@@ -347,29 +348,6 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []
-
-def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
-    """Check if a model supports provider-specific structured output.
-
-    Args:
-        model: Model name string or `BaseChatModel` instance.
-
-    Returns:
-        `True` if the model supports provider-specific structured output, `False` otherwise.
-    """
-    model_name: str | None = None
-    if isinstance(model, str):
-        model_name = model
-    elif isinstance(model, BaseChatModel):
-        model_name = getattr(model, "model_name", None)
-
-    return (
-        "grok" in model_name.lower()
-        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
-        if model_name
-        else False
-    )
-
 def _handle_structured_output_error(
     exception: Exception,
     response_format: ResponseFormat,
@@ -932,16 +910,34 @@ def create_agent( # noqa: PLR0915
         # Determine effective response format (auto-detect if needed)
         effective_response_format: ResponseFormat | None
+        model_name: str = cast(
+            "str",
+            (
+                request.model
+                if isinstance(request.model, str)
+                else getattr(request.model, "model_name", "")
+            ),
+        )
         if isinstance(request.response_format, AutoStrategy):
             # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
-            if _supports_provider_strategy(request.model):
+            if _supports_provider_strategy(model_name):
                 # Model supports provider strategy - use it
                 effective_response_format = ProviderStrategy(schema=request.response_format.schema)
             else:
                 # Model doesn't support provider strategy - use ToolStrategy
                 effective_response_format = ToolStrategy(schema=request.response_format.schema)
+        elif isinstance(request.response_format, ProviderStrategy):
+            if not _supports_provider_strategy(model_name):
+                msg = (
+                    f"Cannot use ProviderStrategy with {model_name}. "
+                    "Supported models: OpenAI (gpt-5, gpt-4.1, gpt-oss, o3-pro, o3-mini), "
+                    "X.AI (Grok). "
+                    "Consider using a raw schema (which auto-selects the best strategy) or "
+                    "explicitly use `ToolStrategy` for unsupported providers."
+                )
+                raise ValueError(msg)
+            effective_response_format = request.response_format
         else:
             # User explicitly specified a strategy - preserve it
             effective_response_format = request.response_format

         # Build final tools list including structured output tools
@@ -957,12 +953,9 @@ def create_agent( # noqa: PLR0915
         if isinstance(effective_response_format, ProviderStrategy):
             # Use provider-specific structured output
             kwargs = effective_response_format.to_model_kwargs()
-            return (
-                request.model.bind_tools(
-                    final_tools, strict=True, **kwargs, **request.model_settings
-                ),
-                effective_response_format,
-            )
+            return request.model.bind_tools(
+                final_tools, **kwargs, **request.model_settings
+            ), effective_response_format

         if isinstance(effective_response_format, ToolStrategy):
             # Current implementation requires that tools used for structured output
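Taken together, the factory changes above resolve the structured-output strategy from the model's name rather than the model object. A minimal standalone sketch of the resulting dispatch, assuming `AutoStrategy` is importable alongside the other strategy classes visible in this diff (the `resolve_strategy` wrapper name is illustrative, not part of the PR):

```python
from langchain.agents.structured_output import (
    AutoStrategy,
    ProviderStrategy,
    ToolStrategy,
    _supports_provider_strategy,
)


def resolve_strategy(model_name: str, response_format):
    """Illustrative restatement of the dispatch added to create_agent."""
    if isinstance(response_format, AutoStrategy):
        # Raw schema: auto-select based on the resolved model's name.
        if _supports_provider_strategy(model_name):
            return ProviderStrategy(schema=response_format.schema)
        return ToolStrategy(schema=response_format.schema)
    if isinstance(response_format, ProviderStrategy) and not _supports_provider_strategy(
        model_name
    ):
        # Explicit ProviderStrategy on an unsupported model now fails fast.
        raise ValueError(f"Cannot use ProviderStrategy with {model_name}. ...")
    # Explicit strategies are otherwise preserved as-is.
    return response_format
```

Because `model_name` is read from the resolved model at bind time, the same dispatch works when middleware swaps the model at runtime, which is what the test changes at the bottom of this compare exercise.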

View File

@@ -31,6 +31,23 @@ SchemaT = TypeVar("SchemaT")
 SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
+
+
+def _supports_provider_strategy(model_name: str) -> bool:
+    """Check if a model supports provider-specific structured output.
+
+    Args:
+        model_name: Model name string.
+
+    Returns:
+        `True` if the model supports provider-specific structured output, `False` otherwise.
+    """
+    return (
+        "grok" in model_name.lower()
+        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
+        if model_name
+        else False
+    )
+
 class StructuredOutputError(Exception):
     """Base class for structured output errors."""
@@ -238,7 +255,56 @@ class ToolStrategy(Generic[SchemaT]):
 @dataclass(init=False)
 class ProviderStrategy(Generic[SchemaT]):
-    """Use the model provider's native structured output method."""
+    """Use the model provider's native structured output method.
+
+    `ProviderStrategy` uses provider-specific structured output APIs that enforce
+    JSON schema validation at the model level. This provides stronger guarantees
+    than tool-based approaches but is only supported by certain providers.
+
+    Supported Providers:
+        - **OpenAI**: All models that support structured outputs (requires `strict=True`)
+        - **X.AI (Grok)**: All models that support structured outputs (requires `strict=True`)
+
+    Important:
+        When using `ProviderStrategy`, the agent will validate at runtime that the
+        model provider is supported. If you're using an unsupported provider, consider:
+
+        - Using a **raw schema** (recommended): Automatically selects the best strategy
+          based on model capabilities
+        - Using **`ToolStrategy`**: Explicitly use tool-based structured output for any
+          provider
+
+    Example:
+        ```python
+        from langchain.agents import create_agent
+        from langchain.agents.structured_output import ProviderStrategy
+        from pydantic import BaseModel
+
+        class WeatherResponse(BaseModel):
+            temperature: float
+            condition: str
+
+        # Explicitly use provider strategy (only for OpenAI/Grok)
+        agent = create_agent(
+            model="openai:gpt-4.1", tools=[], response_format=ProviderStrategy(WeatherResponse)
+        )
+
+        # Or use raw schema for automatic strategy selection (recommended)
+        # This will auto-select ProviderStrategy for OpenAI/Grok, ToolStrategy for others
+        agent = create_agent(
+            model="openai:gpt-4.1",
+            tools=[],
+            response_format=WeatherResponse,  # Auto-selects best strategy
+        )
+        ```
+
+    Note:
+        `ProviderStrategy` can be used with middleware that changes the model at runtime.
+        Validation occurs after the model is resolved, allowing dynamic model selection
+        while ensuring provider compatibility.
+    """

     schema: type[SchemaT]
     """Schema for native mode."""
@@ -255,9 +321,19 @@ class ProviderStrategy(Generic[SchemaT]):
         self.schema_spec = _SchemaSpec(schema)

     def to_model_kwargs(self) -> dict[str, Any]:
-        """Convert to kwargs to bind to a model to force structured output."""
-        # OpenAI:
-        # - see https://platform.openai.com/docs/guides/structured-outputs
+        """Convert to kwargs to bind to a model to force structured output.
+
+        Returns:
+            Model kwargs with `response_format` and `strict`.
+        """
+        # Provider-specific structured output:
+        # - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
+        #   (uses strict=True for schema validation)
+        # - X.AI (Grok): https://docs.x.ai/docs/guides/structured-outputs
+        #   (strict=True is required for schema validation)
         response_format = {
             "type": "json_schema",
             "json_schema": {
@@ -265,7 +341,8 @@ class ProviderStrategy(Generic[SchemaT]):
                 "schema": self.schema_spec.json_schema,
             },
         }
-        return {"response_format": response_format}
+        return {"response_format": response_format, "strict": True}

 @dataclass
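For a concrete sense of the binding change above, `to_model_kwargs` now emits `strict: True` alongside the JSON-schema payload. A hedged sketch of the kwargs produced for a small Pydantic schema (the exact `name` and `schema` values come from `_SchemaSpec`, so the commented shape is approximate):

```python
from pydantic import BaseModel

from langchain.agents.structured_output import ProviderStrategy


class Weather(BaseModel):
    temperature: float
    condition: str


kwargs = ProviderStrategy(Weather).to_model_kwargs()
# Expected shape after this change (per the diff above):
# {
#     "response_format": {
#         "type": "json_schema",
#         "json_schema": {"name": ..., "schema": {...}},
#     },
#     "strict": True,
# }
```

Moving `strict` here (instead of passing `strict=True` at the `bind_tools` call site in the factory) keeps the provider-specific knowledge in one place.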

View File

@@ -11,6 +11,7 @@ from typing import (
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel, LanguageModelInput
+from langchain_core.language_models.base import LangSmithParams
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -29,6 +30,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
     structured_response: StructuredResponseT | None = None
     index: int = 0
     tool_style: Literal["openai", "anthropic"] = "openai"
+    model_name: str = "fake-model"

     def _generate(
         self,
View File

@@ -619,7 +619,81 @@ class TestResponseFormatAsProviderStrategy:
         ]
         model = FakeToolCallingModel[WeatherBaseModel](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="gpt-4.1",
         )
         agent = create_agent(
             model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
         )
         response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
         assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
         assert len(response["messages"]) == 4
+
+    def test_unsupported_model_raises_error(self) -> None:
+        """Test that ProviderStrategy raises ValueError for unsupported models."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+        # Use a model name that doesn't support provider strategy
+        model = FakeToolCallingModel[WeatherBaseModel](
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="claude-3-5-sonnet",
+        )
+        agent = create_agent(
+            model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+        )
+        with pytest.raises(
+            ValueError,
+            match=(
+                r"Cannot use ProviderStrategy with claude-3-5-sonnet\. "
+                r"Supported models: OpenAI \(gpt-5, gpt-4\.1, gpt-oss, o3-pro, o3-mini\), "
+                r"X\.AI \(Grok\)\. "
+                r"Consider using a raw schema \(which auto-selects the best strategy\) or "
+                r"explicitly use `ToolStrategy` for unsupported providers\."
+            ),
+        ):
+            agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+    def test_supported_openai_models(self) -> None:
+        """Test that ProviderStrategy works with all supported OpenAI model variants."""
+        supported_models = ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"]
+        for model_name in supported_models:
+            tool_calls = [
+                [{"args": {}, "id": "1", "name": "get_weather"}],
+            ]
+            model = FakeToolCallingModel[WeatherBaseModel](
+                tool_calls=tool_calls,
+                structured_response=EXPECTED_WEATHER_PYDANTIC,
+                model_name=model_name,
+            )
+            agent = create_agent(
+                model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+            )
+            response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+            assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+            assert len(response["messages"]) == 4
+
+    def test_supported_grok_model(self) -> None:
+        """Test that ProviderStrategy works with Grok models."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+        model = FakeToolCallingModel[WeatherBaseModel](
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="grok-beta",
+        )
+        agent = create_agent(
@@ -637,7 +711,9 @@ class TestResponseFormatAsProviderStrategy:
         ]
         model = FakeToolCallingModel[WeatherDataclass](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_DATACLASS,
+            model_name="gpt-4.1",
         )
         agent = create_agent(
@@ -657,7 +733,7 @@ class TestResponseFormatAsProviderStrategy:
         ]
         model = FakeToolCallingModel[WeatherTypedDict](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
         )
         agent = create_agent(
@@ -675,7 +751,7 @@ class TestResponseFormatAsProviderStrategy:
         ]
         model = FakeToolCallingModel[dict](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
        )
         agent = create_agent(
@@ -697,13 +773,13 @@ class TestDynamicModelWithResponseFormat:
         on the middleware-modified model (not the original), ensuring the correct strategy is
         selected based on the final model's capabilities.
         """
-        from unittest.mock import patch
-
         from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
         from langchain_core.language_models.fake_chat_models import GenericFakeChatModel

-        # Custom model that we'll use to test whether the tool strategy is applied
-        # correctly at runtime.
+        # Custom model that we'll use to test whether the provider strategy is applied
+        # correctly at runtime. Use a model_name that supports provider strategy.
         class CustomModel(GenericFakeChatModel):
+            model_name: str = "gpt-4.1"
             tool_bindings: list[Any] = []

             def bind_tools(
@@ -736,14 +812,6 @@ class TestDynamicModelWithResponseFormat:
                 request.model = model
                 return handler(request)

-        # Track which model is checked for provider strategy support
-        calls = []
-
-        def mock_supports_provider_strategy(model) -> bool:
-            """Track which model is checked and return True for ProviderStrategy."""
-            calls.append(model)
-            return True
-
         # Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
         # This should auto-detect strategy based on model capabilities
         agent = create_agent(
@@ -754,14 +822,7 @@ class TestDynamicModelWithResponseFormat:
             middleware=[ModelSwappingMiddleware()],
         )

-        with patch(
-            "langchain.agents.factory._supports_provider_strategy",
-            side_effect=mock_supports_provider_strategy,
-        ):
-            response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
-
-        # Verify strategy resolution was deferred: check was called once during _get_bound_model
-        assert len(calls) == 1
+        response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

         # Verify successful parsing of JSON as structured output via ProviderStrategy
         assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC