anthropic: Allow kwargs to pass through when counting tokens (#31082)

- **Description:** `ChatAnthropic.get_num_tokens_from_messages` does not
currently receive `kwargs` and pass those on to
`self._client.beta.messages.count_tokens`. This is a problem if you need
to pass specific options to `count_tokens`, such as the `thinking`
option. This PR fixes that.
- **Issue:** N/A
- **Dependencies:** None
- **Twitter handle:** @bengladwell

Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
Ben Gladwell 2025-04-30 17:56:22 -04:00 committed by GitHub
parent 918c950737
commit da59eb7eb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 1 deletions

View File

@ -1588,6 +1588,7 @@ class ChatAnthropic(BaseChatModel):
tools: Optional[
Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
] = None,
**kwargs: Any,
) -> int:
"""Count tokens in a sequence of input messages.
@ -1647,7 +1648,6 @@ class ChatAnthropic(BaseChatModel):
https://docs.anthropic.com/en/docs/build-with-claude/token-counting
"""
formatted_system, formatted_messages = _format_messages(messages)
kwargs: dict[str, Any] = {}
if isinstance(formatted_system, str):
kwargs["system"] = formatted_system
if tools:

View File

@ -2,7 +2,9 @@
import os
from typing import Any, Callable, Literal, cast
from unittest.mock import patch
import anthropic
import pytest
from anthropic.types import Message, TextBlock, Usage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
@ -940,3 +942,15 @@ def test_optional_description() -> None:
sample_field: str
_ = llm.with_structured_output(SampleModel.model_json_schema())
def test_get_num_tokens_from_messages_passes_kwargs() -> None:
"""Test that get_num_tokens_from_messages passes kwargs to the model."""
llm = ChatAnthropic(model="claude-3-5-haiku-latest")
with patch.object(anthropic, "Client") as _Client:
llm.get_num_tokens_from_messages([HumanMessage("foo")], foo="bar")
assert (
_Client.return_value.beta.messages.count_tokens.call_args.kwargs["foo"] == "bar"
)