feat(anthropic): use model profile for max output tokens (#34163)

Need(?) to adjust tests to also pull from model profile? currently
hardcoded
This commit is contained in:
Mason Daugherty
2025-12-08 15:31:16 -05:00
committed by GitHub
parent dcb670f395
commit 91d5ca275d

View File

@@ -89,34 +89,9 @@ def _get_default_model_profile(model_name: str) -> ModelProfile:
return {} return {}
_MODEL_DEFAULT_MAX_OUTPUT_TOKENS: Final[dict[str, int]] = {
# Listed old to new
"claude-3-haiku": 4096, # Claude Haiku 3
"claude-3-5-haiku": 8192, # Claude Haiku 3.5
"claude-3-7-sonnet": 64000, # Claude Sonnet 3.7
"claude-sonnet-4": 64000, # Claude Sonnet 4
"claude-opus-4": 32000, # Claude Opus 4
"claude-opus-4-1": 32000, # Claude Opus 4.1
"claude-sonnet-4-5": 64000, # Claude Sonnet 4.5
"claude-haiku-4-5": 64000, # Claude Haiku 4.5
}
_FALLBACK_MAX_OUTPUT_TOKENS: Final[int] = 4096 _FALLBACK_MAX_OUTPUT_TOKENS: Final[int] = 4096
def _default_max_tokens_for(model: str | None) -> int:
"""Return the default max output tokens for an Anthropic model (with fallback).
See the Claude docs for [Max Tokens limits](https://platform.claude.com/docs/en/about-claude/models/overview#model-comparison-table).
"""
if not model:
return _FALLBACK_MAX_OUTPUT_TOKENS
parts = model.split("-")
family = "-".join(parts[:-1]) if len(parts) > 1 else model
return _MODEL_DEFAULT_MAX_OUTPUT_TOKENS.get(family, _FALLBACK_MAX_OUTPUT_TOKENS)
class AnthropicTool(TypedDict): class AnthropicTool(TypedDict):
"""Anthropic tool definition.""" """Anthropic tool definition."""
@@ -1869,10 +1844,13 @@ class ChatAnthropic(BaseChatModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def set_default_max_tokens(cls, values: dict[str, Any]) -> Any: def set_default_max_tokens(cls, values: dict[str, Any]) -> Any:
"""Set default `max_tokens`.""" """Set default `max_tokens` from model profile with fallback."""
if values.get("max_tokens") is None: if values.get("max_tokens") is None:
model = values.get("model") or values.get("model_name") model = values.get("model") or values.get("model_name")
values["max_tokens"] = _default_max_tokens_for(model) profile = _get_default_model_profile(model) if model else {}
values["max_tokens"] = profile.get(
"max_output_tokens", _FALLBACK_MAX_OUTPUT_TOKENS
)
return values return values
@model_validator(mode="before") @model_validator(mode="before")