From caed4e4ce87e01a724b4f7b9ade90532412cdc68 Mon Sep 17 00:00:00 2001
From: Chester Curme
Date: Tue, 5 Nov 2024 11:09:07 -0500
Subject: [PATCH] implement ChatAnthropic.get_num_tokens_from_messages

---
 .../langchain_anthropic/chat_models.py        | 30 +++++++++++++-
 .../integration_tests/test_chat_models.py     | 39 +++++++++++++++++++
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 5a6ca4cc6b9..9e428296ad6 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -21,7 +21,7 @@ from typing import (
 )
 
 import anthropic
-from langchain_core._api import deprecated
+from langchain_core._api import beta, deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -1113,6 +1113,34 @@ class ChatAnthropic(BaseChatModel):
         else:
             return llm | output_parser
 
+    @beta()
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Count tokens in a sequence of input messages."""
+        if any(
+            isinstance(tool, ToolMessage)
+            or (isinstance(tool, AIMessage) and tool.tool_calls)
+            for tool in messages
+        ):
+            raise NotImplementedError(
+                "get_num_tokens_from_messages does not yet support counting tokens "
+                "in tool calls."
+            )
+        system, messages = _format_messages(messages)
+        if isinstance(system, str):
+            response = self._client.beta.messages.count_tokens(
+                betas=["token-counting-2024-11-01"],
+                model=self.model,
+                system=system,
+                messages=messages,
+            )
+        else:
+            response = self._client.beta.messages.count_tokens(
+                betas=["token-counting-2024-11-01"],
+                model=self.model,
+                messages=messages,
+            )
+        return response.input_tokens
+
 
 class AnthropicTool(TypedDict):
     """Anthropic tool definition."""
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index 037caf718da..5a7a1f780b9 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -334,6 +334,8 @@ def test_anthropic_multimodal() -> None:
     response = chat.invoke(messages)
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+    num_tokens = chat.get_num_tokens_from_messages(messages)
+    assert num_tokens > 0
 
 
 def test_streaming() -> None:
@@ -505,6 +507,43 @@ def test_with_structured_output() -> None:
     assert response["location"]
 
 
+def test_get_num_tokens_from_messages() -> None:
+    llm = ChatAnthropic(model="claude-3-5-haiku-20241022")  # type: ignore[call-arg]
+
+    # Test simple case
+    messages = [
+        SystemMessage(content="You are an assistant."),
+        HumanMessage(content="What is the weather in SF?"),
+    ]
+    num_tokens = llm.get_num_tokens_from_messages(messages)
+    assert num_tokens > 0
+
+    # Test tool use (not yet supported)
+    messages = [
+        AIMessage(
+            content=[
+                {"text": "Let's see.", "type": "text"},
+                {
+                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
+                    "input": {"location": "SF"},
+                    "name": "get_weather",
+                    "type": "tool_use",
+                },
+            ],
+            tool_calls=[
+                {
+                    "name": "get_weather",
+                    "args": {"location": "SF"},
+                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
+                    "type": "tool_call",
+                },
+            ],
+        )
+    ]
+    with pytest.raises(NotImplementedError):
+        num_tokens = llm.get_num_tokens_from_messages(messages)
+
+
 class GetWeather(BaseModel):
     """Get the current weather in a given location"""