better structure

This commit is contained in:
Sydney Runkle
2025-10-16 13:20:47 -04:00
parent 7157c2f69c
commit 5155fbe072
4 changed files with 47 additions and 78 deletions

View File

@@ -42,6 +42,7 @@ from langchain.agents.structured_output import (
ResponseFormat,
StructuredOutputValidationError,
ToolStrategy,
_supports_provider_strategy,
)
from langchain.chat_models import init_chat_model
from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
@@ -347,29 +348,6 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
return []
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
"""Check if a model supports provider-specific structured output.
Args:
model: Model name string or `BaseChatModel` instance.
Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
"""
model_name: str | None = None
if isinstance(model, str):
model_name = model
elif isinstance(model, BaseChatModel):
model_name = getattr(model, "model_name", None)
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
if model_name
else False
)
def _handle_structured_output_error(
exception: Exception,
response_format: ResponseFormat,

View File

@@ -23,6 +23,7 @@ from typing_extensions import Self, is_typeddict
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts
@@ -31,6 +32,30 @@ SchemaT = TypeVar("SchemaT")
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
def _supports_provider_strategy(model: str | Any) -> bool:
"""Check if a model supports provider-specific structured output.
Args:
model: Model name string or `BaseChatModel` instance.
Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
"""
model_name: str | None = None
if isinstance(model, str):
model_name = model
else:
# Fall back to the instance's `model_name` attribute; None when the model does not expose one
model_name = getattr(model, "model_name", None)
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
if model_name
else False
)
class StructuredOutputError(Exception):
"""Base class for structured output errors."""
@@ -329,11 +354,10 @@ class ProviderStrategy(Generic[SchemaT]):
# Both providers require strict=True for structured output
kwargs: dict[str, Any] = {"response_format": response_format}
if model is not None and hasattr(model, "_get_ls_params"):
ls_params = model._get_ls_params()
provider = ls_params.get("ls_provider", "").lower()
if provider in ("openai", "xai"):
kwargs["strict"] = True
# Set strict=True only when _supports_provider_strategy recognizes the model.
# The check matches model-name patterns for OpenAI (gpt-5, gpt-4.1, o3-*) and Grok models.
if model is not None and _supports_provider_strategy(model):
kwargs["strict"] = True
return kwargs

View File

@@ -31,6 +31,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"
ls_provider: str = "openai"
model_name: str = "fake-model"
def _generate(
self,
@@ -54,6 +55,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
tool_calls = []
if is_native and not tool_calls:
content_obj = {}
if isinstance(self.structured_response, BaseModel):
content_obj = self.structured_response.model_dump()
elif is_dataclass(self.structured_response):

View File

@@ -619,7 +619,7 @@ class TestResponseFormatAsProviderStrategy:
]
model = FakeToolCallingModel[WeatherBaseModel](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC, model_name="gpt-4.1"
)
agent = create_agent(
@@ -637,7 +637,7 @@ class TestResponseFormatAsProviderStrategy:
]
model = FakeToolCallingModel[WeatherDataclass](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS, model_name="gpt-4.1"
)
agent = create_agent(
@@ -657,7 +657,7 @@ class TestResponseFormatAsProviderStrategy:
]
model = FakeToolCallingModel[WeatherTypedDict](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
)
agent = create_agent(
@@ -675,7 +675,7 @@ class TestResponseFormatAsProviderStrategy:
]
model = FakeToolCallingModel[dict](
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
)
agent = create_agent(
@@ -697,13 +697,13 @@ class TestDynamicModelWithResponseFormat:
on the middleware-modified model (not the original), ensuring the correct strategy is
selected based on the final model's capabilities.
"""
from unittest.mock import patch
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
# Custom model that we'll use to test whether the tool strategy is applied
# correctly at runtime.
# Custom model that we'll use to test whether the provider strategy is applied
# correctly at runtime. Use a model_name that supports provider strategy.
class CustomModel(GenericFakeChatModel):
model_name: str = "gpt-4.1"
tool_bindings: list[Any] = []
def bind_tools(
@@ -715,12 +715,6 @@ class TestDynamicModelWithResponseFormat:
self.tool_bindings.append(tools)
return self
def _get_ls_params(self, **kwargs: Any):
"""Return OpenAI as provider to pass ProviderStrategy validation."""
from langchain_core.language_models.base import LangSmithParams
return LangSmithParams(ls_provider="openai", ls_model_type="chat")
model = CustomModel(
messages=iter(
[
@@ -742,14 +736,6 @@ class TestDynamicModelWithResponseFormat:
request.model = model
return handler(request)
# Track which model is checked for provider strategy support
calls = []
def mock_supports_provider_strategy(model) -> bool:
"""Track which model is checked and return True for ProviderStrategy."""
calls.append(model)
return True
# Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
# This should auto-detect strategy based on model capabilities
agent = create_agent(
@@ -760,14 +746,7 @@ class TestDynamicModelWithResponseFormat:
middleware=[ModelSwappingMiddleware()],
)
with patch(
"langchain.agents.factory._supports_provider_strategy",
side_effect=mock_supports_provider_strategy,
):
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
# Verify strategy resolution was deferred: check was called once during _get_bound_model
assert len(calls) == 1
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
# Verify successful parsing of JSON as structured output via ProviderStrategy
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
@@ -811,17 +790,14 @@ def test_union_of_types() -> None:
def test_provider_strategy_strict_only_for_openai() -> None:
"""Test that strict=True is set for OpenAI and Grok models in ProviderStrategy."""
from langchain.agents.structured_output import ProviderStrategy
from langchain_core.language_models.base import LangSmithParams
# Create a mock OpenAI model
# Create a mock OpenAI model with model_name
class MockOpenAIModel:
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
return LangSmithParams(ls_provider="openai", ls_model_type="chat")
model_name: str = "gpt-4.1"
# Create a mock Grok/X.AI model
# Create a mock Grok/X.AI model with model_name
class MockGrokModel:
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
return LangSmithParams(ls_provider="xai", ls_model_type="chat")
model_name: str = "grok-beta"
provider_strategy = ProviderStrategy(WeatherBaseModel)
@@ -843,25 +819,14 @@ def test_provider_strategy_strict_only_for_openai() -> None:
def test_provider_strategy_validation() -> None:
"""Test that ProviderStrategy validates provider support at agent invocation time."""
from langchain.agents.structured_output import ProviderStrategy
from langchain_core.language_models.base import LangSmithParams
# Create a mock model from an unsupported provider (e.g., Anthropic)
# Use a model_name that doesn't match any supported patterns
class MockAnthropicModel(FakeToolCallingModel):
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
return LangSmithParams(ls_provider="anthropic", ls_model_type="chat")
# Create a mock model without _get_ls_params
class MockModelNoLSParams(FakeToolCallingModel):
def _get_ls_params(self, **kwargs: Any):
msg = "This model doesn't support _get_ls_params"
raise AttributeError(msg)
model_name: str = "claude-3-5-sonnet-20241022"
# Test unsupported provider: should raise ValueError when invoking agent
anthropic_model = MockAnthropicModel(tool_calls=[[]])
agent = create_agent(anthropic_model, [], response_format=ProviderStrategy(WeatherBaseModel))
with pytest.raises(ValueError, match="does not support provider 'anthropic'"):
with pytest.raises(ValueError, match="ProviderStrategy does not support this model"):
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
# NOTE(review): these comments still describe the old `_get_ls_params`-based validation;
# after this change, provider support is decided by model_name patterns via
# _supports_provider_strategy — update or remove this stale commentary.