feat(langchain): reference model profiles for provider strategy (#33974)

This commit is contained in:
ccurme
2025-11-14 14:24:18 -05:00
committed by GitHub
parent 189dcf7295
commit 6aa3794b74
4 changed files with 75 additions and 12 deletions

View File

@@ -63,6 +63,18 @@ if TYPE_CHECKING:
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
# if langchain-model-profiles is not installed, these models are assumed to support
# structured output
"grok",
"gpt-5",
"gpt-4.1",
"gpt-4o",
"gpt-oss",
"o3-pro",
"o3-mini",
]
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
"""Normalize middleware return value to ModelResponse."""
@@ -349,11 +361,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
return []
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
"""Check if a model supports provider-specific structured output.
Args:
model: Model name string or `BaseChatModel` instance.
tools: Optional list of tools provided to the agent. Needed because some models
don't support structured output together with tool calling.
Returns:
`True` if the model supports provider-specific structured output, `False` otherwise.
@@ -362,11 +376,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
if isinstance(model, str):
model_name = model
elif isinstance(model, BaseChatModel):
model_name = getattr(model, "model_name", None)
model_name = (
getattr(model, "model_name", None)
or getattr(model, "model", None)
or getattr(model, "model_id", "")
)
try:
model_profile = model.profile
except ImportError:
pass
else:
if (
model_profile.get("structured_output")
# We make an exception for Gemini models, which currently do not support
# simultaneous tool use with structured output
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
):
return True
return (
"grok" in model_name.lower()
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
if model_name
else False
)
@@ -988,7 +1017,7 @@ def create_agent( # noqa: PLR0915
effective_response_format: ResponseFormat | None
if isinstance(request.response_format, AutoStrategy):
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
if _supports_provider_strategy(request.model):
if _supports_provider_strategy(request.model, tools=request.tools):
# Model supports provider strategy - use it
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
else:

View File

@@ -57,6 +57,7 @@ test = [
"pytest-mock",
"syrupy>=4.0.2,<5.0.0",
"toml>=0.10.2,<1.0.0",
"langchain-model-profiles",
"langchain-tests",
"langchain-openai",
]
@@ -75,6 +76,7 @@ test_integration = [
"cassio>=0.1.0,<1.0.0",
"langchainhub>=0.1.16,<1.0.0",
"langchain-core",
"langchain-model-profiles",
"langchain-text-splitters",
]
@@ -83,6 +85,7 @@ prerelease = "allow"
[tool.uv.sources]
langchain-core = { path = "../core", editable = true }
langchain-model-profiles = { path = "../model-profiles", editable = true }
langchain-tests = { path = "../standard-tests", editable = true }
langchain-text-splitters = { path = "../text-splitters", editable = true }
langchain-openai = { path = "../partners/openai", editable = true }

View File

@@ -790,7 +790,7 @@ class TestDynamicModelWithResponseFormat:
# Track which model is checked for provider strategy support
calls = []
def mock_supports_provider_strategy(model) -> bool:
def mock_supports_provider_strategy(model, tools) -> bool:
"""Track which model is checked and return True for ProviderStrategy."""
calls.append(model)
return True

View File

@@ -995,7 +995,7 @@ name = "exceptiongroup"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.12'" },
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
wheels = [
@@ -1991,6 +1991,7 @@ lint = [
{ name = "ruff" },
]
test = [
{ name = "langchain-model-profiles" },
{ name = "langchain-openai" },
{ name = "langchain-tests" },
{ name = "pytest" },
@@ -2006,6 +2007,7 @@ test = [
test-integration = [
{ name = "cassio" },
{ name = "langchain-core" },
{ name = "langchain-model-profiles" },
{ name = "langchain-text-splitters" },
{ name = "langchainhub" },
{ name = "python-dotenv" },
@@ -2031,7 +2033,7 @@ requires-dist = [
{ name = "langchain-groq", marker = "extra == 'groq'" },
{ name = "langchain-huggingface", marker = "extra == 'huggingface'" },
{ name = "langchain-mistralai", marker = "extra == 'mistralai'" },
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'" },
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'", editable = "../model-profiles" },
{ name = "langchain-ollama", marker = "extra == 'ollama'" },
{ name = "langchain-openai", marker = "extra == 'openai'", editable = "../partners/openai" },
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
@@ -2045,6 +2047,7 @@ provides-extras = ["model-profiles", "community", "anthropic", "openai", "azure-
[package.metadata.requires-dev]
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13.0" }]
test = [
{ name = "langchain-model-profiles", editable = "../model-profiles" },
{ name = "langchain-openai", editable = "../partners/openai" },
{ name = "langchain-tests", editable = "../standard-tests" },
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
@@ -2060,6 +2063,7 @@ test = [
test-integration = [
{ name = "cassio", specifier = ">=0.1.0,<1.0.0" },
{ name = "langchain-core", editable = "../core" },
{ name = "langchain-model-profiles", editable = "../model-profiles" },
{ name = "langchain-text-splitters", editable = "../text-splitters" },
{ name = "langchainhub", specifier = ">=0.1.16,<1.0.0" },
{ name = "python-dotenv", specifier = ">=1.0.0,<2.0.0" },
@@ -2339,14 +2343,41 @@ wheels = [
[[package]]
name = "langchain-model-profiles"
version = "0.0.4"
source = { registry = "https://pypi.org/simple" }
source = { editable = "../model-profiles" }
dependencies = [
{ name = "tomli", marker = "python_full_version < '3.11'" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/69/c3/cbadf3e884bfbd57f4604a68d67132ece45c3510a1ec710a5193b4b1a1af/langchain_model_profiles-0.0.4.tar.gz", hash = "sha256:b66909339c9175a6963e7fcdacae382b4773f8da04092b9dd64424b8e8b1f8c8", size = 145898, upload-time = "2025-11-10T17:08:44.875Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/50/f0/a848f99a9d70f40c2f46c8d9465549adc6e8b678658be3ca11ce16287e24/langchain_model_profiles-0.0.4-py3-none-any.whl", hash = "sha256:7382b7feb2294ded84fe89b20a4d656f81f40c8f024233205b2c5507391ce1ba", size = 30419, upload-time = "2025-11-10T17:08:43.948Z" },
[package.metadata]
requires-dist = [
{ name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0,<3.0.0" },
{ name = "typing-extensions", specifier = ">=4.7.0,<5.0.0" },
]
[package.metadata.requires-dev]
dev = [{ name = "httpx", specifier = ">=0.23.0,<1" }]
lint = [
{ name = "langchain", editable = "." },
{ name = "ruff", specifier = ">=0.12.2,<0.13.0" },
]
test = [
{ name = "langchain", extras = ["openai"], editable = "." },
{ name = "langchain-core", editable = "../core" },
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
{ name = "pytest-asyncio", specifier = ">=0.23.2,<2.0.0" },
{ name = "pytest-cov", specifier = ">=4.0.0,<8.0.0" },
{ name = "pytest-mock" },
{ name = "pytest-socket", specifier = ">=0.6.0,<1.0.0" },
{ name = "pytest-watcher", specifier = ">=0.2.6,<1.0.0" },
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
{ name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
{ name = "toml", specifier = ">=0.10.2,<1.0.0" },
]
test-integration = [{ name = "langchain-core", editable = "../core" }]
typing = [
{ name = "mypy", specifier = ">=1.18.1,<1.19.0" },
{ name = "types-toml", specifier = ">=0.10.8.20240310,<1.0.0.0" },
]
[[package]]