anthropic[major]: support python 3.13 (#27916)

Last week Anthropic released version 0.39.0 of its python sdk, which
enabled support for Python 3.13. This release deleted a legacy
`client.count_tokens` method, which we currently access during init of
the `Anthropic` LLM. Anthropic has replaced this functionality with the
[client.beta.messages.count_tokens()
API](https://github.com/anthropics/anthropic-sdk-python/pull/726).

To enable support for `anthropic >= 0.39.0` and Python 3.13, here we
drop support for the legacy token counting method, and add support for
the new method via `ChatAnthropic.get_num_tokens_from_messages`.

To fully support the token counting API, we update the signature of
`get_num_tokens_from_message` to accept tools everywhere.

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
ccurme
2024-11-12 14:31:07 -05:00
committed by GitHub
parent 759b6ed17a
commit 1538ee17f9
14 changed files with 534 additions and 542 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from functools import cache
@@ -364,17 +365,31 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
"""
if tools is not None:
warnings.warn(
"Counting tokens in tool schemas is not yet supported. Ignoring tools.",
stacklevel=2,
)
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod

View File

@@ -1,5 +1,8 @@
import base64
import json
import typing
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import pytest
@@ -19,6 +22,7 @@ from langchain_core.messages.utils import (
merge_message_runs,
trim_messages,
)
from langchain_core.tools import BaseTool
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -431,7 +435,15 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
class FakeTokenCountingModel(FakeChatModel):
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
]
] = None,
) -> int:
return dummy_token_counter(messages)