@model_validator(mode="before")
@classmethod
def validate_max_tokens(cls, values: dict[str, Any]) -> Any:
    """Fill in a model-appropriate default for ``max_tokens`` when unset.

    Per-model maximum output-token limits are documented at
    https://docs.anthropic.com/en/docs/about-claude/models/overview#model-comparison-table

    Args:
        values: Raw constructor keyword arguments (pre-validation).

    Returns:
        The ``values`` mapping, possibly with ``max_tokens`` populated.
    """
    # ``max_tokens`` is aliased to ``max_tokens_to_sample``; respect an
    # explicit value supplied under either key.
    if (
        values.get("max_tokens") is not None
        or values.get("max_tokens_to_sample") is not None
    ):
        return values
    # ``model`` is aliased to ``model_name``; check both keys and hoist the
    # lookup so it is done once instead of per-branch.
    model = values.get("model") or values.get("model_name") or ""
    if not model:
        return values
    if "claude-opus-4" in model:
        values["max_tokens"] = 32000
    elif "claude-sonnet-4" in model or "claude-3-7-sonnet" in model:
        values["max_tokens"] = 64000
    elif "claude-3-5-sonnet" in model or "claude-3-5-haiku" in model:
        values["max_tokens"] = 8192
    else:
        # Remaining families (e.g. claude-3-haiku) cap at 4096 output tokens.
        values["max_tokens"] = 4096
    return values
def test_validate_max_tokens() -> None:
    """Test the validate_max_tokens function through class initialization.

    Checks that each model family receives its documented default output-token
    limit, and that an explicitly supplied ``max_tokens`` is never overridden.
    """
    # Expected default max_tokens per model family; unmatched families
    # (e.g. claude-3-haiku) fall back to 4096.
    expected_defaults = {
        "claude-opus-4-20250514": 32000,
        "claude-sonnet-4-latest": 64000,
        "claude-3-7-sonnet-latest": 64000,
        "claude-3-5-sonnet-latest": 8192,
        "claude-3-5-haiku-latest": 8192,
        "claude-3-haiku-latest": 4096,
    }
    for model_name, expected in expected_defaults.items():
        llm = ChatAnthropic(model=model_name, anthropic_api_key="test")
        # Include the model name so a failure identifies the offending case.
        assert llm.max_tokens == expected, model_name

    # Explicitly set max_tokens values are preserved, including one that
    # collides with a family default (4096).
    for explicit in (2048, 4096):
        llm = ChatAnthropic(
            model="claude-3-5-sonnet-latest",
            max_tokens=explicit,
            anthropic_api_key="test",
        )
        assert llm.max_tokens == explicit
test_anthropic_model_name_param() -> None: llm = ChatAnthropic(model_name="foo") # type: ignore[call-arg, call-arg]