Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-22 23:00:00 +00:00.
Implement ChatAnthropic.get_num_tokens_from_messages.
This commit is contained in:
parent
ff2ef48b35
commit
caed4e4ce8
@ -21,7 +21,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import anthropic
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core._api import beta, deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@ -1113,6 +1113,34 @@ class ChatAnthropic(BaseChatModel):
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
@beta()
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count tokens in a sequence of input messages.

    Delegates to Anthropic's beta token-counting endpoint, so the result
    reflects the server-side tokenization for ``self.model``.

    Args:
        messages: The message history to count tokens for.

    Returns:
        The number of input tokens the messages would consume.

    Raises:
        NotImplementedError: If any message is a ``ToolMessage`` or an
            ``AIMessage`` carrying tool calls; counting tokens in tool
            calls is not yet supported.
    """
    if any(
        isinstance(message, ToolMessage)
        or (isinstance(message, AIMessage) and message.tool_calls)
        for message in messages
    ):
        raise NotImplementedError(
            "get_num_tokens_from_messages does not yet support counting tokens "
            "in tool calls."
        )
    system, messages = _format_messages(messages)
    # Include `system` only when _format_messages produced a string prompt,
    # so the count_tokens call does not need to be duplicated per branch.
    extra = {"system": system} if isinstance(system, str) else {}
    response = self._client.beta.messages.count_tokens(
        betas=["token-counting-2024-11-01"],
        model=self.model,
        messages=messages,
        **extra,
    )
    return response.input_tokens
|
||||
|
||||
|
||||
class AnthropicTool(TypedDict):
|
||||
"""Anthropic tool definition."""
|
||||
|
@ -334,6 +334,8 @@ def test_anthropic_multimodal() -> None:
|
||||
response = chat.invoke(messages)
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, str)
|
||||
num_tokens = chat.get_num_tokens_from_messages(messages)
|
||||
assert num_tokens > 0
|
||||
|
||||
|
||||
def test_streaming() -> None:
|
||||
@ -505,6 +507,43 @@ def test_with_structured_output() -> None:
|
||||
assert response["location"]
|
||||
|
||||
|
||||
def test_get_num_tokens_from_messages() -> None:
    """Integration test for ``ChatAnthropic.get_num_tokens_from_messages``.

    Covers the supported plain-message path (positive token count) and the
    unsupported tool-call path (must raise ``NotImplementedError``).
    """
    llm = ChatAnthropic(model="claude-3-5-haiku-20241022")  # type: ignore[call-arg]

    # Simple case: system + human message should yield a positive count.
    messages = [
        SystemMessage(content="You are an assistant."),
        HumanMessage(content="What is the weather in SF?"),
    ]
    num_tokens = llm.get_num_tokens_from_messages(messages)
    assert num_tokens > 0

    # Tool use is not yet supported for token counting and must raise.
    messages = [
        AIMessage(
            content=[
                {"text": "Let's see.", "type": "text"},
                {
                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
                    "input": {"location": "SF"},
                    "name": "get_weather",
                    "type": "tool_use",
                },
            ],
            tool_calls=[
                {
                    "name": "get_weather",
                    "args": {"location": "SF"},
                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
                    "type": "tool_call",
                },
            ],
        )
    ]
    with pytest.raises(NotImplementedError):
        llm.get_num_tokens_from_messages(messages)
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
"""Get the current weather in a given location"""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user