implement ChatAnthropic.get_num_tokens_from_messages

Chester Curme 2024-11-05 11:09:07 -05:00
parent ff2ef48b35
commit caed4e4ce8
2 changed files with 68 additions and 1 deletion


@@ -21,7 +21,7 @@ from typing import (
)
import anthropic
-from langchain_core._api import deprecated
+from langchain_core._api import beta, deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
@@ -1113,6 +1113,34 @@ class ChatAnthropic(BaseChatModel):
        else:
            return llm | output_parser

    @beta()
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens in a sequence of input messages."""
        if any(
            isinstance(tool, ToolMessage)
            or (isinstance(tool, AIMessage) and tool.tool_calls)
            for tool in messages
        ):
            raise NotImplementedError(
                "get_num_tokens_from_messages does not yet support counting tokens "
                "in tool calls."
            )
        system, messages = _format_messages(messages)
        if isinstance(system, str):
            response = self._client.beta.messages.count_tokens(
                betas=["token-counting-2024-11-01"],
                model=self.model,
                system=system,
                messages=messages,
            )
        else:
            response = self._client.beta.messages.count_tokens(
                betas=["token-counting-2024-11-01"],
                model=self.model,
                messages=messages,
            )
        return response.input_tokens


class AnthropicTool(TypedDict):
    """Anthropic tool definition."""


@@ -334,6 +334,8 @@ def test_anthropic_multimodal() -> None:
    response = chat.invoke(messages)
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    num_tokens = chat.get_num_tokens_from_messages(messages)
    assert num_tokens > 0


def test_streaming() -> None:
@@ -505,6 +507,43 @@ def test_with_structured_output() -> None:
    assert response["location"]


def test_get_num_tokens_from_messages() -> None:
    llm = ChatAnthropic(model="claude-3-5-haiku-20241022")  # type: ignore[call-arg]

    # Test simple case
    messages = [
        SystemMessage(content="You are an assistant."),
        HumanMessage(content="What is the weather in SF?"),
    ]
    num_tokens = llm.get_num_tokens_from_messages(messages)
    assert num_tokens > 0

    # Test tool use (not yet supported)
    messages = [
        AIMessage(
            content=[
                {"text": "Let's see.", "type": "text"},
                {
                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
                    "input": {"location": "SF"},
                    "name": "get_weather",
                    "type": "tool_use",
                },
            ],
            tool_calls=[
                {
                    "name": "get_weather",
                    "args": {"location": "SF"},
                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
                    "type": "tool_call",
                },
            ],
        )
    ]
    with pytest.raises(NotImplementedError):
        num_tokens = llm.get_num_tokens_from_messages(messages)


class GetWeather(BaseModel):
    """Get the current weather in a given location"""