mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
better structure
This commit is contained in:
@@ -42,6 +42,7 @@ from langchain.agents.structured_output import (
|
||||
ResponseFormat,
|
||||
StructuredOutputValidationError,
|
||||
ToolStrategy,
|
||||
_supports_provider_strategy,
|
||||
)
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
|
||||
@@ -347,29 +348,6 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
return []
|
||||
|
||||
|
||||
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
"""Check if a model supports provider-specific structured output.
|
||||
|
||||
Args:
|
||||
model: Model name string or `BaseChatModel` instance.
|
||||
|
||||
Returns:
|
||||
`True` if the model supports provider-specific structured output, `False` otherwise.
|
||||
"""
|
||||
model_name: str | None = None
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
elif isinstance(model, BaseChatModel):
|
||||
model_name = getattr(model, "model_name", None)
|
||||
|
||||
return (
|
||||
"grok" in model_name.lower()
|
||||
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
|
||||
if model_name
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
def _handle_structured_output_error(
|
||||
exception: Exception,
|
||||
response_format: ResponseFormat,
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing_extensions import Self, is_typeddict
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts
|
||||
@@ -31,6 +32,30 @@ SchemaT = TypeVar("SchemaT")
|
||||
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
|
||||
|
||||
|
||||
def _supports_provider_strategy(model: str | Any) -> bool:
|
||||
"""Check if a model supports provider-specific structured output.
|
||||
|
||||
Args:
|
||||
model: Model name string or `BaseChatModel` instance.
|
||||
|
||||
Returns:
|
||||
`True` if the model supports provider-specific structured output, `False` otherwise.
|
||||
"""
|
||||
model_name: str | None = None
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
else:
|
||||
# Try to get model_name attribute from model instance
|
||||
model_name = getattr(model, "model_name", None)
|
||||
|
||||
return (
|
||||
"grok" in model_name.lower()
|
||||
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
|
||||
if model_name
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
class StructuredOutputError(Exception):
|
||||
"""Base class for structured output errors."""
|
||||
|
||||
@@ -329,11 +354,10 @@ class ProviderStrategy(Generic[SchemaT]):
|
||||
# Both providers require strict=True for structured output
|
||||
kwargs: dict[str, Any] = {"response_format": response_format}
|
||||
|
||||
if model is not None and hasattr(model, "_get_ls_params"):
|
||||
ls_params = model._get_ls_params()
|
||||
provider = ls_params.get("ls_provider", "").lower()
|
||||
if provider in ("openai", "xai"):
|
||||
kwargs["strict"] = True
|
||||
# Use _supports_provider_strategy to determine if we should set strict=True
|
||||
# This checks model name patterns for OpenAI and Grok models
|
||||
if model is not None and _supports_provider_strategy(model):
|
||||
kwargs["strict"] = True
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
|
||||
index: int = 0
|
||||
tool_style: Literal["openai", "anthropic"] = "openai"
|
||||
ls_provider: str = "openai"
|
||||
model_name: str = "fake-model"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -54,6 +55,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
|
||||
tool_calls = []
|
||||
|
||||
if is_native and not tool_calls:
|
||||
content_obj = {}
|
||||
if isinstance(self.structured_response, BaseModel):
|
||||
content_obj = self.structured_response.model_dump()
|
||||
elif is_dataclass(self.structured_response):
|
||||
|
||||
@@ -619,7 +619,7 @@ class TestResponseFormatAsProviderStrategy:
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[WeatherBaseModel](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC, model_name="gpt-4.1"
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
@@ -637,7 +637,7 @@ class TestResponseFormatAsProviderStrategy:
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[WeatherDataclass](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS, model_name="gpt-4.1"
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
@@ -657,7 +657,7 @@ class TestResponseFormatAsProviderStrategy:
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[WeatherTypedDict](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
@@ -675,7 +675,7 @@ class TestResponseFormatAsProviderStrategy:
|
||||
]
|
||||
|
||||
model = FakeToolCallingModel[dict](
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
|
||||
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
@@ -697,13 +697,13 @@ class TestDynamicModelWithResponseFormat:
|
||||
on the middleware-modified model (not the original), ensuring the correct strategy is
|
||||
selected based on the final model's capabilities.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
|
||||
# Custom model that we'll use to test whether the tool strategy is applied
|
||||
# correctly at runtime.
|
||||
# Custom model that we'll use to test whether the provider strategy is applied
|
||||
# correctly at runtime. Use a model_name that supports provider strategy.
|
||||
class CustomModel(GenericFakeChatModel):
|
||||
model_name: str = "gpt-4.1"
|
||||
tool_bindings: list[Any] = []
|
||||
|
||||
def bind_tools(
|
||||
@@ -715,12 +715,6 @@ class TestDynamicModelWithResponseFormat:
|
||||
self.tool_bindings.append(tools)
|
||||
return self
|
||||
|
||||
def _get_ls_params(self, **kwargs: Any):
|
||||
"""Return OpenAI as provider to pass ProviderStrategy validation."""
|
||||
from langchain_core.language_models.base import LangSmithParams
|
||||
|
||||
return LangSmithParams(ls_provider="openai", ls_model_type="chat")
|
||||
|
||||
model = CustomModel(
|
||||
messages=iter(
|
||||
[
|
||||
@@ -742,14 +736,6 @@ class TestDynamicModelWithResponseFormat:
|
||||
request.model = model
|
||||
return handler(request)
|
||||
|
||||
# Track which model is checked for provider strategy support
|
||||
calls = []
|
||||
|
||||
def mock_supports_provider_strategy(model) -> bool:
|
||||
"""Track which model is checked and return True for ProviderStrategy."""
|
||||
calls.append(model)
|
||||
return True
|
||||
|
||||
# Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
|
||||
# This should auto-detect strategy based on model capabilities
|
||||
agent = create_agent(
|
||||
@@ -760,14 +746,7 @@ class TestDynamicModelWithResponseFormat:
|
||||
middleware=[ModelSwappingMiddleware()],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"langchain.agents.factory._supports_provider_strategy",
|
||||
side_effect=mock_supports_provider_strategy,
|
||||
):
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
# Verify strategy resolution was deferred: check was called once during _get_bound_model
|
||||
assert len(calls) == 1
|
||||
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
# Verify successful parsing of JSON as structured output via ProviderStrategy
|
||||
assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
|
||||
@@ -811,17 +790,14 @@ def test_union_of_types() -> None:
|
||||
def test_provider_strategy_strict_only_for_openai() -> None:
|
||||
"""Test that strict=True is set for OpenAI and Grok models in ProviderStrategy."""
|
||||
from langchain.agents.structured_output import ProviderStrategy
|
||||
from langchain_core.language_models.base import LangSmithParams
|
||||
|
||||
# Create a mock OpenAI model
|
||||
# Create a mock OpenAI model with model_name
|
||||
class MockOpenAIModel:
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
|
||||
return LangSmithParams(ls_provider="openai", ls_model_type="chat")
|
||||
model_name: str = "gpt-4.1"
|
||||
|
||||
# Create a mock Grok/X.AI model
|
||||
# Create a mock Grok/X.AI model with model_name
|
||||
class MockGrokModel:
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
|
||||
return LangSmithParams(ls_provider="xai", ls_model_type="chat")
|
||||
model_name: str = "grok-beta"
|
||||
|
||||
provider_strategy = ProviderStrategy(WeatherBaseModel)
|
||||
|
||||
@@ -843,25 +819,14 @@ def test_provider_strategy_strict_only_for_openai() -> None:
|
||||
def test_provider_strategy_validation() -> None:
|
||||
"""Test that ProviderStrategy validates provider support at agent invocation time."""
|
||||
from langchain.agents.structured_output import ProviderStrategy
|
||||
from langchain_core.language_models.base import LangSmithParams
|
||||
|
||||
# Create a mock model from an unsupported provider (e.g., Anthropic)
|
||||
# Use a model_name that doesn't match any supported patterns
|
||||
class MockAnthropicModel(FakeToolCallingModel):
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithParams:
|
||||
return LangSmithParams(ls_provider="anthropic", ls_model_type="chat")
|
||||
|
||||
# Create a mock model without _get_ls_params
|
||||
class MockModelNoLSParams(FakeToolCallingModel):
|
||||
def _get_ls_params(self, **kwargs: Any):
|
||||
msg = "This model doesn't support _get_ls_params"
|
||||
raise AttributeError(msg)
|
||||
model_name: str = "claude-3-5-sonnet-20241022"
|
||||
|
||||
# Test unsupported provider: should raise ValueError when invoking agent
|
||||
anthropic_model = MockAnthropicModel(tool_calls=[[]])
|
||||
agent = create_agent(anthropic_model, [], response_format=ProviderStrategy(WeatherBaseModel))
|
||||
with pytest.raises(ValueError, match="does not support provider 'anthropic'"):
|
||||
with pytest.raises(ValueError, match="ProviderStrategy does not support this model"):
|
||||
agent.invoke({"messages": [HumanMessage("What's the weather?")]})
|
||||
|
||||
# Test model without proper _get_ls_params: still works if model has the method
|
||||
# (validation checks hasattr and calls it)
|
||||
# We can't easily test the "no _get_ls_params" case without breaking BaseChatModel
|
||||
|
||||
Reference in New Issue
Block a user