This commit is contained in:
Mohammad Mohtashim 2025-07-29 00:50:55 +00:00 committed by GitHub
commit 35ef249700
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 63 additions and 1 deletion

View File

@ -1146,7 +1146,7 @@ class ChatAnthropic(BaseChatModel):
model: str = Field(alias="model_name")
"""Model name to use."""
max_tokens: int = Field(default=1024, alias="max_tokens_to_sample")
max_tokens: Optional[int] = Field(default=None, alias="max_tokens_to_sample")
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
@ -1276,6 +1276,29 @@ class ChatAnthropic(BaseChatModel):
ls_params["ls_stop"] = ls_stop
return ls_params
@model_validator(mode="before")
@classmethod
def validate_max_tokens(cls, values: dict[str, Any]) -> Any:
    """Apply a model-appropriate default for ``max_tokens`` when unset.

    Runs in "before" mode, so ``values`` is the raw input dict and may use
    either field names or their aliases (``model_name`` for ``model``,
    ``max_tokens_to_sample`` for ``max_tokens``) — both spellings are
    honored here.

    Token limits per model:
    https://docs.anthropic.com/en/docs/about-claude/models/overview#model-comparison-table
    """
    # Respect an explicitly supplied value under either the field name or
    # its alias; only fill in a default when neither is present.
    explicitly_set = (
        values.get("max_tokens") is not None
        or values.get("max_tokens_to_sample") is not None
    )
    model = values.get("model") or values.get("model_name") or ""
    if not explicitly_set and model:
        if "claude-opus-4" in model:
            values["max_tokens"] = 32000
        elif "claude-sonnet-4" in model or "claude-3-7-sonnet" in model:
            values["max_tokens"] = 64000
        elif "claude-3-5-sonnet" in model or "claude-3-5-haiku" in model:
            values["max_tokens"] = 8192
        else:
            # Fallback (e.g. claude-3-haiku and unrecognized models).
            values["max_tokens"] = 4096
    return values
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict) -> Any:

View File

@ -62,6 +62,45 @@ def test_anthropic_client_caching() -> None:
assert llm1._client._client is not llm5._client._client
def test_validate_max_tokens() -> None:
    """Test the validate_max_tokens function through class initialization."""
    # Each model family should receive its documented default max_tokens.
    default_cases = [
        ("claude-opus-4-20250514", 32000),
        ("claude-sonnet-4-latest", 64000),
        ("claude-3-7-sonnet-latest", 64000),
        ("claude-3-5-sonnet-latest", 8192),
        ("claude-3-5-haiku-latest", 8192),
        ("claude-3-haiku-latest", 4096),  # fallback default
    ]
    for model_name, expected_tokens in default_cases:
        llm = ChatAnthropic(model=model_name, anthropic_api_key="test")
        assert llm.max_tokens == expected_tokens
    # An explicitly provided max_tokens must never be overridden by the
    # validator, even when it matches a family default.
    for explicit_tokens in (2048, 4096):
        llm = ChatAnthropic(
            model="claude-3-5-sonnet-latest",
            max_tokens=explicit_tokens,
            anthropic_api_key="test",
        )
        assert llm.max_tokens == explicit_tokens
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
llm = ChatAnthropic(model_name="foo") # type: ignore[call-arg, call-arg]