anthropic: Allow kwargs to pass through when counting tokens (#31082)

- **Description:** `ChatAnthropic.get_num_tokens_from_messages` does not
currently receive `kwargs` and pass those on to
`self._client.beta.messages.count_tokens`. This is a problem if you need
to pass specific options to `count_tokens`, such as the `thinking`
option. This PR fixes that.
- **Issue:** N/A
- **Dependencies:** None
- **Twitter handle:** @bengladwell

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Ben Gladwell 2025-04-30 17:56:22 -04:00 committed by GitHub
parent 918c950737
commit da59eb7eb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 1 deletion

View File

@@ -1588,6 +1588,7 @@ class ChatAnthropic(BaseChatModel):
         tools: Optional[
             Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
         ] = None,
+        **kwargs: Any,
     ) -> int:
         """Count tokens in a sequence of input messages.
@@ -1647,7 +1648,6 @@ class ChatAnthropic(BaseChatModel):
         https://docs.anthropic.com/en/docs/build-with-claude/token-counting
         """
         formatted_system, formatted_messages = _format_messages(messages)
-        kwargs: dict[str, Any] = {}
         if isinstance(formatted_system, str):
             kwargs["system"] = formatted_system
         if tools:

View File

@@ -2,7 +2,9 @@
 import os
 from typing import Any, Callable, Literal, cast
+from unittest.mock import patch
 
+import anthropic
 import pytest
 from anthropic.types import Message, TextBlock, Usage
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
@@ -940,3 +942,15 @@ def test_optional_description() -> None:
         sample_field: str
 
     _ = llm.with_structured_output(SampleModel.model_json_schema())
def test_get_num_tokens_from_messages_passes_kwargs() -> None:
    """Verify that extra kwargs are forwarded to ``count_tokens``.

    ``get_num_tokens_from_messages`` accepts arbitrary keyword arguments
    (e.g. ``thinking``) and must pass them through to the underlying
    ``client.beta.messages.count_tokens`` call.
    """
    llm = ChatAnthropic(model="claude-3-5-haiku-latest")
    # Patch the Anthropic client class so the token-count call is
    # intercepted by a mock instead of reaching the API.
    with patch.object(anthropic, "Client") as mock_client_cls:
        llm.get_num_tokens_from_messages([HumanMessage("foo")], foo="bar")
    count_tokens = mock_client_cls.return_value.beta.messages.count_tokens
    assert count_tokens.call_args.kwargs["foo"] == "bar"