Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-21 21:56:38 +00:00)

Compare commits: 11 commits (langchain= ... strategy-b)
| Author | SHA1 | Date |
|---|---|---|
|  | c16969d8ac |  |
|  | 038e78fa95 |  |
|  | e18a21a8b0 |  |
|  | c8b1162067 |  |
|  | d8d94e6c95 |  |
|  | e4ab6896a7 |  |
|  | 6efea75d1c |  |
|  | 5155fbe072 |  |
|  | 7157c2f69c |  |
|  | c97e649bd3 |  |
|  | 602284dc3e |  |
@@ -13,7 +13,6 @@ from typing import (
     get_type_hints,
 )

-from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph._internal._runnable import RunnableCallable
@@ -42,6 +41,7 @@ from langchain.agents.structured_output import (
     ResponseFormat,
     StructuredOutputValidationError,
     ToolStrategy,
+    _supports_provider_strategy,
 )
 from langchain.chat_models import init_chat_model
 from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
@@ -49,6 +49,7 @@ from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Sequence

+    from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
     from langgraph.graph.state import CompiledStateGraph
@@ -347,29 +348,6 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []


-def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
-    """Check if a model supports provider-specific structured output.
-
-    Args:
-        model: Model name string or `BaseChatModel` instance.
-
-    Returns:
-        `True` if the model supports provider-specific structured output, `False` otherwise.
-    """
-    model_name: str | None = None
-    if isinstance(model, str):
-        model_name = model
-    elif isinstance(model, BaseChatModel):
-        model_name = getattr(model, "model_name", None)
-
-    return (
-        "grok" in model_name.lower()
-        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
-        if model_name
-        else False
-    )
-
-
 def _handle_structured_output_error(
     exception: Exception,
     response_format: ResponseFormat,
@@ -932,16 +910,34 @@ def create_agent( # noqa: PLR0915

         # Determine effective response format (auto-detect if needed)
         effective_response_format: ResponseFormat | None
+        model_name: str = cast(
+            "str",
+            (
+                request.model
+                if isinstance(request.model, str)
+                else getattr(request.model, "model_name", "")
+            ),
+        )
         if isinstance(request.response_format, AutoStrategy):
             # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
-            if _supports_provider_strategy(request.model):
+            if _supports_provider_strategy(model_name):
                 # Model supports provider strategy - use it
                 effective_response_format = ProviderStrategy(schema=request.response_format.schema)
             else:
                 # Model doesn't support provider strategy - use ToolStrategy
                 effective_response_format = ToolStrategy(schema=request.response_format.schema)
+        elif isinstance(request.response_format, ProviderStrategy):
+            if not _supports_provider_strategy(model_name):
+                msg = (
+                    f"Cannot use ProviderStrategy with {model_name}. "
+                    "Supported models: OpenAI (gpt-5, gpt-4.1, gpt-oss, o3-pro, o3-mini), "
+                    "X.AI (Grok). "
+                    "Consider using a raw schema (which auto-selects the best strategy) or "
+                    "explicitly use `ToolStrategy` for unsupported providers."
+                )
+                raise ValueError(msg)
+            effective_response_format = request.response_format
         else:
             # User explicitly specified a strategy - preserve it
             effective_response_format = request.response_format

         # Build final tools list including structured output tools
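To make the resolution order above easier to follow outside the diff, here is a minimal, self-contained sketch of the same behavior. The dataclasses and the `resolve_format` / `_PROVIDER_MODELS` names are stand-ins invented for this illustration, not the real classes in `langchain.agents.structured_output`:

```python
from dataclasses import dataclass
from typing import Any

# Model families the diff treats as supporting provider-native structured output.
_PROVIDER_MODELS = ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"]


@dataclass
class ProviderStrategy:  # simplified stand-in for the real strategy class
    schema: Any


@dataclass
class ToolStrategy:  # simplified stand-in for the real strategy class
    schema: Any


def resolve_format(model: Any, schema: Any, explicit_provider: bool = False) -> Any:
    """Mirror the branch above: auto-detect for raw schemas, validate explicit ProviderStrategy."""
    # The model name is the string itself, or the instance's `model_name` attribute.
    name = model if isinstance(model, str) else getattr(model, "model_name", "")
    supported = bool(name) and (
        "grok" in name.lower() or any(part in name for part in _PROVIDER_MODELS)
    )
    if explicit_provider and not supported:
        raise ValueError(f"Cannot use ProviderStrategy with {name}.")
    return ProviderStrategy(schema=schema) if supported else ToolStrategy(schema=schema)


print(type(resolve_format("gpt-4.1", dict)).__name__)            # ProviderStrategy
print(type(resolve_format("claude-3-5-sonnet", dict)).__name__)  # ToolStrategy
```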
@@ -957,12 +953,9 @@ def create_agent( # noqa: PLR0915
         if isinstance(effective_response_format, ProviderStrategy):
             # Use provider-specific structured output
             kwargs = effective_response_format.to_model_kwargs()
-            return (
-                request.model.bind_tools(
-                    final_tools, strict=True, **kwargs, **request.model_settings
-                ),
-                effective_response_format,
-            )
+            return request.model.bind_tools(
+                final_tools, **kwargs, **request.model_settings
+            ), effective_response_format

         if isinstance(effective_response_format, ToolStrategy):
             # Current implementation requires that tools used for structured output
@@ -31,6 +31,23 @@ SchemaT = TypeVar("SchemaT")
 SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]


+def _supports_provider_strategy(model_name: str) -> bool:
+    """Check if a model supports provider-specific structured output.
+
+    Args:
+        model_name: Model name string.
+
+    Returns:
+        `True` if the model supports provider-specific structured output, `False` otherwise.
+    """
+    return (
+        "grok" in model_name.lower()
+        or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
+        if model_name
+        else False
+    )
+
+
 class StructuredOutputError(Exception):
     """Base class for structured output errors."""

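Concretely, the substring checks above give results like these (illustrative calls; the import path follows the import added to `factory.py` in this diff, and the helper is private, so this is for explanation rather than public use):

```python
# Import path per this change; _supports_provider_strategy is a private helper.
from langchain.agents.structured_output import _supports_provider_strategy

assert _supports_provider_strategy("gpt-4.1-mini")           # "gpt-4.1" substring matches
assert _supports_provider_strategy("grok-beta")              # "grok" matches, case-insensitively
assert _supports_provider_strategy("o3-mini")                # listed OpenAI family
assert not _supports_provider_strategy("claude-3-5-sonnet")  # no provider-native support
assert not _supports_provider_strategy("")                   # empty name falls through to False
```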
@@ -238,7 +255,56 @@ class ToolStrategy(Generic[SchemaT]):

 @dataclass(init=False)
 class ProviderStrategy(Generic[SchemaT]):
-    """Use the model provider's native structured output method."""
+    """Use the model provider's native structured output method.
+
+    `ProviderStrategy` uses provider-specific structured output APIs that enforce
+    JSON schema validation at the model level. This provides stronger guarantees
+    than tool-based approaches but is only supported by certain providers.
+
+    Supported Providers:
+        - **OpenAI**: All models that support structured outputs (requires `strict=True`)
+        - **X.AI (Grok)**: All models that support structured outputs (requires `strict=True`)
+
+    Important:
+        When using `ProviderStrategy`, the agent will validate at runtime that the
+        model provider is supported. If you're using an unsupported provider, consider:
+
+        - Using a **raw schema** (recommended): Automatically selects the best strategy
+          based on model capabilities
+        - Using **`ToolStrategy`**: Explicitly use tool-based structured output for any
+          provider
+
+    Example:
+        ```python
+        from langchain.agents import create_agent
+        from langchain.agents.structured_output import ProviderStrategy
+        from pydantic import BaseModel
+
+
+        class WeatherResponse(BaseModel):
+            temperature: float
+            condition: str
+
+
+        # Explicitly use provider strategy (only for OpenAI/Grok)
+        agent = create_agent(
+            model="openai:gpt-4", tools=[], response_format=ProviderStrategy(WeatherResponse)
+        )
+
+        # Or use raw schema for automatic strategy selection (recommended)
+        # This will auto-select ProviderStrategy for OpenAI/Grok, ToolStrategy for others
+        agent = create_agent(
+            model="openai:gpt-4",
+            tools=[],
+            response_format=WeatherResponse,  # Auto-selects best strategy
+        )
+        ```
+
+    Note:
+        `ProviderStrategy` can be used with middleware that changes the model at runtime.
+        Validation occurs after the model is resolved, allowing dynamic model selection
+        while ensuring provider compatibility.
+    """

     schema: type[SchemaT]
     """Schema for native mode."""
@@ -255,9 +321,19 @@ class ProviderStrategy(Generic[SchemaT]):
         self.schema_spec = _SchemaSpec(schema)

     def to_model_kwargs(self) -> dict[str, Any]:
-        """Convert to kwargs to bind to a model to force structured output."""
-        # OpenAI:
-        # - see https://platform.openai.com/docs/guides/structured-outputs
+        """Convert to kwargs to bind to a model to force structured output.
+
+        Args:
+            model: The model instance to check provider for conditional `strict` param.
+
+        Returns:
+            Model kwargs with `response_format` and optionally `strict`.
+        """
+        # Provider-specific structured output:
+        # - OpenAI: https://platform.openai.com/docs/guides/structured-outputs
+        #   - Uses strict=True for schema validation
+        # - X.AI (Grok): https://docs.x.ai/docs/guides/structured-outputs
+        #   - Uses strict=True for schema validation (required)
         response_format = {
             "type": "json_schema",
             "json_schema": {
@@ -265,7 +341,8 @@ class ProviderStrategy(Generic[SchemaT]):
                 "schema": self.schema_spec.json_schema,
             },
         }
-        return {"response_format": response_format}
+
+        return {"response_format": response_format, "strict": True}


 @dataclass
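For a small Pydantic schema like the `WeatherResponse` example in the docstring above, the kwargs returned by `to_model_kwargs` after this change have roughly the following shape. This is a sketch: the `name` key and the exact JSON schema are produced by `_SchemaSpec`, which is not shown in this diff, so both are assumptions here.

```python
# Approximate return value of ProviderStrategy(WeatherResponse).to_model_kwargs()
expected_kwargs = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "WeatherResponse",  # assumed: derived from the schema by _SchemaSpec
            "schema": {                 # assumed: _SchemaSpec's JSON Schema for the model
                "type": "object",
                "properties": {
                    "temperature": {"type": "number"},
                    "condition": {"type": "string"},
                },
                "required": ["temperature", "condition"],
            },
        },
    },
    # New in this change: `strict` is carried in the kwargs, so the factory's
    # bind_tools call no longer passes strict=True explicitly.
    "strict": True,
}
```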
@@ -11,6 +11,7 @@ from typing import (

 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import BaseChatModel, LanguageModelInput
+from langchain_core.language_models.base import LangSmithParams
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -29,6 +30,7 @@ class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
     structured_response: StructuredResponseT | None = None
     index: int = 0
     tool_style: Literal["openai", "anthropic"] = "openai"
+    model_name: str = "fake-model"

     def _generate(
         self,
@@ -619,7 +619,81 @@ class TestResponseFormatAsProviderStrategy:
         ]

         model = FakeToolCallingModel[WeatherBaseModel](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="gpt-4.1",
+        )
+
+        agent = create_agent(
+            model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+        )
+        response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+        assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+        assert len(response["messages"]) == 4
+
+    def test_unsupported_model_raises_error(self) -> None:
+        """Test that ProviderStrategy raises ValueError for unsupported models."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+
+        # Use a model name that doesn't support provider strategy
+        model = FakeToolCallingModel[WeatherBaseModel](
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="claude-3-5-sonnet",
+        )
+
+        agent = create_agent(
+            model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=(
+                r"Cannot use ProviderStrategy with claude-3-5-sonnet\. "
+                r"Supported models: OpenAI \(gpt-5, gpt-4\.1, gpt-oss, o3-pro, o3-mini\), "
+                r"X\.AI \(Grok\)\. "
+                r"Consider using a raw schema \(which auto-selects the best strategy\) or "
+                r"explicitly use `ToolStrategy` for unsupported providers\."
+            ),
+        ):
+            agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+    def test_supported_openai_models(self) -> None:
+        """Test that ProviderStrategy works with all supported OpenAI model variants."""
+        supported_models = ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"]
+
+        for model_name in supported_models:
+            tool_calls = [
+                [{"args": {}, "id": "1", "name": "get_weather"}],
+            ]
+
+            model = FakeToolCallingModel[WeatherBaseModel](
+                tool_calls=tool_calls,
+                structured_response=EXPECTED_WEATHER_PYDANTIC,
+                model_name=model_name,
+            )
+
+            agent = create_agent(
+                model, [get_weather], response_format=ProviderStrategy(WeatherBaseModel)
+            )
+            response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
+
+            assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC
+            assert len(response["messages"]) == 4
+
+    def test_supported_grok_model(self) -> None:
+        """Test that ProviderStrategy works with Grok models."""
+        tool_calls = [
+            [{"args": {}, "id": "1", "name": "get_weather"}],
+        ]
+
+        model = FakeToolCallingModel[WeatherBaseModel](
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_PYDANTIC,
+            model_name="grok-beta",
         )

         agent = create_agent(
@@ -637,7 +711,9 @@ class TestResponseFormatAsProviderStrategy:
         ]

         model = FakeToolCallingModel[WeatherDataclass](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
+            tool_calls=tool_calls,
+            structured_response=EXPECTED_WEATHER_DATACLASS,
+            model_name="gpt-4.1",
         )

         agent = create_agent(
@@ -657,7 +733,7 @@ class TestResponseFormatAsProviderStrategy:
         ]

         model = FakeToolCallingModel[WeatherTypedDict](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
         )

         agent = create_agent(
@@ -675,7 +751,7 @@ class TestResponseFormatAsProviderStrategy:
         ]

         model = FakeToolCallingModel[dict](
-            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
+            tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT, model_name="gpt-4.1"
         )

         agent = create_agent(
@@ -697,13 +773,13 @@ class TestDynamicModelWithResponseFormat:
         on the middleware-modified model (not the original), ensuring the correct strategy is
         selected based on the final model's capabilities.
         """
-        from unittest.mock import patch
         from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
         from langchain_core.language_models.fake_chat_models import GenericFakeChatModel

-        # Custom model that we'll use to test whether the tool strategy is applied
-        # correctly at runtime.
+        # Custom model that we'll use to test whether the provider strategy is applied
+        # correctly at runtime. Use a model_name that supports provider strategy.
         class CustomModel(GenericFakeChatModel):
+            model_name: str = "gpt-4.1"
             tool_bindings: list[Any] = []

             def bind_tools(
@@ -736,14 +812,6 @@ class TestDynamicModelWithResponseFormat:
                 request.model = model
                 return handler(request)

-        # Track which model is checked for provider strategy support
-        calls = []
-
-        def mock_supports_provider_strategy(model) -> bool:
-            """Track which model is checked and return True for ProviderStrategy."""
-            calls.append(model)
-            return True
-
         # Use raw Pydantic model (not wrapped in ToolStrategy or ProviderStrategy)
         # This should auto-detect strategy based on model capabilities
         agent = create_agent(
@@ -754,14 +822,7 @@ class TestDynamicModelWithResponseFormat:
             middleware=[ModelSwappingMiddleware()],
         )

-        with patch(
-            "langchain.agents.factory._supports_provider_strategy",
-            side_effect=mock_supports_provider_strategy,
-        ):
-            response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
-
-        # Verify strategy resolution was deferred: check was called once during _get_bound_model
-        assert len(calls) == 1
+        response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})

         # Verify successful parsing of JSON as structured output via ProviderStrategy
         assert response["structured_response"] == EXPECTED_WEATHER_PYDANTIC