mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
feat(langchain): reference model profiles for provider strategy (#33974)
This commit is contained in:
@@ -63,6 +63,18 @@ if TYPE_CHECKING:
|
||||
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
|
||||
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
|
||||
# if langchain-model-profiles is not installed, these models are assumed to support
|
||||
# structured output
|
||||
"grok",
|
||||
"gpt-5",
|
||||
"gpt-4.1",
|
||||
"gpt-4o",
|
||||
"gpt-oss",
|
||||
"o3-pro",
|
||||
"o3-mini",
|
||||
]
|
||||
|
||||
|
||||
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
|
||||
"""Normalize middleware return value to ModelResponse."""
|
||||
@@ -349,11 +361,13 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
|
||||
return []
|
||||
|
||||
|
||||
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
|
||||
"""Check if a model supports provider-specific structured output.
|
||||
|
||||
Args:
|
||||
model: Model name string or `BaseChatModel` instance.
|
||||
tools: Optional list of tools provided to the agent. Needed because some models
|
||||
don't support structured output together with tool calling.
|
||||
|
||||
Returns:
|
||||
`True` if the model supports provider-specific structured output, `False` otherwise.
|
||||
@@ -362,11 +376,26 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
if isinstance(model, str):
|
||||
model_name = model
|
||||
elif isinstance(model, BaseChatModel):
|
||||
model_name = getattr(model, "model_name", None)
|
||||
model_name = (
|
||||
getattr(model, "model_name", None)
|
||||
or getattr(model, "model", None)
|
||||
or getattr(model, "model_id", "")
|
||||
)
|
||||
try:
|
||||
model_profile = model.profile
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
if (
|
||||
model_profile.get("structured_output")
|
||||
# We make an exception for Gemini models, which currently do not support
|
||||
# simultaneous tool use with structured output
|
||||
and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
|
||||
):
|
||||
return True
|
||||
|
||||
return (
|
||||
"grok" in model_name.lower()
|
||||
or any(part in model_name for part in ["gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini"])
|
||||
any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
|
||||
if model_name
|
||||
else False
|
||||
)
|
||||
@@ -988,7 +1017,7 @@ def create_agent( # noqa: PLR0915
|
||||
effective_response_format: ResponseFormat | None
|
||||
if isinstance(request.response_format, AutoStrategy):
|
||||
# User provided raw schema via AutoStrategy - auto-detect best strategy based on model
|
||||
if _supports_provider_strategy(request.model):
|
||||
if _supports_provider_strategy(request.model, tools=request.tools):
|
||||
# Model supports provider strategy - use it
|
||||
effective_response_format = ProviderStrategy(schema=request.response_format.schema)
|
||||
else:
|
||||
|
||||
@@ -57,6 +57,7 @@ test = [
|
||||
"pytest-mock",
|
||||
"syrupy>=4.0.2,<5.0.0",
|
||||
"toml>=0.10.2,<1.0.0",
|
||||
"langchain-model-profiles",
|
||||
"langchain-tests",
|
||||
"langchain-openai",
|
||||
]
|
||||
@@ -75,6 +76,7 @@ test_integration = [
|
||||
"cassio>=0.1.0,<1.0.0",
|
||||
"langchainhub>=0.1.16,<1.0.0",
|
||||
"langchain-core",
|
||||
"langchain-model-profiles",
|
||||
"langchain-text-splitters",
|
||||
]
|
||||
|
||||
@@ -83,6 +85,7 @@ prerelease = "allow"
|
||||
|
||||
[tool.uv.sources]
|
||||
langchain-core = { path = "../core", editable = true }
|
||||
langchain-model-profiles = { path = "../model-profiles", editable = true }
|
||||
langchain-tests = { path = "../standard-tests", editable = true }
|
||||
langchain-text-splitters = { path = "../text-splitters", editable = true }
|
||||
langchain-openai = { path = "../partners/openai", editable = true }
|
||||
|
||||
@@ -790,7 +790,7 @@ class TestDynamicModelWithResponseFormat:
|
||||
# Track which model is checked for provider strategy support
|
||||
calls = []
|
||||
|
||||
def mock_supports_provider_strategy(model) -> bool:
|
||||
def mock_supports_provider_strategy(model, tools) -> bool:
|
||||
"""Track which model is checked and return True for ProviderStrategy."""
|
||||
calls.append(model)
|
||||
return True
|
||||
|
||||
43
libs/langchain_v1/uv.lock
generated
43
libs/langchain_v1/uv.lock
generated
@@ -995,7 +995,7 @@ name = "exceptiongroup"
|
||||
version = "1.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.12'" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
|
||||
wheels = [
|
||||
@@ -1991,6 +1991,7 @@ lint = [
|
||||
{ name = "ruff" },
|
||||
]
|
||||
test = [
|
||||
{ name = "langchain-model-profiles" },
|
||||
{ name = "langchain-openai" },
|
||||
{ name = "langchain-tests" },
|
||||
{ name = "pytest" },
|
||||
@@ -2006,6 +2007,7 @@ test = [
|
||||
test-integration = [
|
||||
{ name = "cassio" },
|
||||
{ name = "langchain-core" },
|
||||
{ name = "langchain-model-profiles" },
|
||||
{ name = "langchain-text-splitters" },
|
||||
{ name = "langchainhub" },
|
||||
{ name = "python-dotenv" },
|
||||
@@ -2031,7 +2033,7 @@ requires-dist = [
|
||||
{ name = "langchain-groq", marker = "extra == 'groq'" },
|
||||
{ name = "langchain-huggingface", marker = "extra == 'huggingface'" },
|
||||
{ name = "langchain-mistralai", marker = "extra == 'mistralai'" },
|
||||
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'" },
|
||||
{ name = "langchain-model-profiles", marker = "extra == 'model-profiles'", editable = "../model-profiles" },
|
||||
{ name = "langchain-ollama", marker = "extra == 'ollama'" },
|
||||
{ name = "langchain-openai", marker = "extra == 'openai'", editable = "../partners/openai" },
|
||||
{ name = "langchain-perplexity", marker = "extra == 'perplexity'" },
|
||||
@@ -2045,6 +2047,7 @@ provides-extras = ["model-profiles", "community", "anthropic", "openai", "azure-
|
||||
[package.metadata.requires-dev]
|
||||
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13.0" }]
|
||||
test = [
|
||||
{ name = "langchain-model-profiles", editable = "../model-profiles" },
|
||||
{ name = "langchain-openai", editable = "../partners/openai" },
|
||||
{ name = "langchain-tests", editable = "../standard-tests" },
|
||||
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
|
||||
@@ -2060,6 +2063,7 @@ test = [
|
||||
test-integration = [
|
||||
{ name = "cassio", specifier = ">=0.1.0,<1.0.0" },
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "langchain-model-profiles", editable = "../model-profiles" },
|
||||
{ name = "langchain-text-splitters", editable = "../text-splitters" },
|
||||
{ name = "langchainhub", specifier = ">=0.1.16,<1.0.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0,<2.0.0" },
|
||||
@@ -2339,14 +2343,41 @@ wheels = [
|
||||
[[package]]
|
||||
name = "langchain-model-profiles"
|
||||
version = "0.0.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
source = { editable = "../model-profiles" }
|
||||
dependencies = [
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/69/c3/cbadf3e884bfbd57f4604a68d67132ece45c3510a1ec710a5193b4b1a1af/langchain_model_profiles-0.0.4.tar.gz", hash = "sha256:b66909339c9175a6963e7fcdacae382b4773f8da04092b9dd64424b8e8b1f8c8", size = 145898, upload-time = "2025-11-10T17:08:44.875Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/50/f0/a848f99a9d70f40c2f46c8d9465549adc6e8b678658be3ca11ce16287e24/langchain_model_profiles-0.0.4-py3-none-any.whl", hash = "sha256:7382b7feb2294ded84fe89b20a4d656f81f40c8f024233205b2c5507391ce1ba", size = 30419, upload-time = "2025-11-10T17:08:43.948Z" },
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'", specifier = ">=2.0.0,<3.0.0" },
|
||||
{ name = "typing-extensions", specifier = ">=4.7.0,<5.0.0" },
|
||||
]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "httpx", specifier = ">=0.23.0,<1" }]
|
||||
lint = [
|
||||
{ name = "langchain", editable = "." },
|
||||
{ name = "ruff", specifier = ">=0.12.2,<0.13.0" },
|
||||
]
|
||||
test = [
|
||||
{ name = "langchain", extras = ["openai"], editable = "." },
|
||||
{ name = "langchain-core", editable = "../core" },
|
||||
{ name = "pytest", specifier = ">=8.0.0,<9.0.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.23.2,<2.0.0" },
|
||||
{ name = "pytest-cov", specifier = ">=4.0.0,<8.0.0" },
|
||||
{ name = "pytest-mock" },
|
||||
{ name = "pytest-socket", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "pytest-watcher", specifier = ">=0.2.6,<1.0.0" },
|
||||
{ name = "pytest-xdist", specifier = ">=3.6.1,<4.0.0" },
|
||||
{ name = "syrupy", specifier = ">=4.0.2,<5.0.0" },
|
||||
{ name = "toml", specifier = ">=0.10.2,<1.0.0" },
|
||||
]
|
||||
test-integration = [{ name = "langchain-core", editable = "../core" }]
|
||||
typing = [
|
||||
{ name = "mypy", specifier = ">=1.18.1,<1.19.0" },
|
||||
{ name = "types-toml", specifier = ">=0.10.8.20240310,<1.0.0.0" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user