mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 14:23:58 +00:00
update anthropic
This commit is contained in:
parent
e4bfc84d6e
commit
db7b518308
@ -1114,37 +1114,38 @@ class ChatAnthropic(BaseChatModel):
|
||||
return llm | output_parser
|
||||
|
||||
@beta()
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
tools: Optional[
|
||||
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
|
||||
] = None,
|
||||
) -> int:
|
||||
"""Count tokens in a sequence of input messages.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||
to be converted to tool schemas.
|
||||
|
||||
.. versionchanged:: 0.2.5
|
||||
|
||||
Uses Anthropic's token counting API to count tokens in messages. See:
|
||||
https://docs.anthropic.com/en/api/messages-count-tokens
|
||||
https://docs.anthropic.com/en/docs/build-with-claude/token-counting
|
||||
"""
|
||||
if any(
|
||||
isinstance(tool, ToolMessage)
|
||||
or (isinstance(tool, AIMessage) and tool.tool_calls)
|
||||
for tool in messages
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"get_num_tokens_from_messages does not yet support counting tokens "
|
||||
"in tool calls."
|
||||
)
|
||||
formatted_system, formatted_messages = _format_messages(messages)
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if isinstance(formatted_system, str):
|
||||
response = self._client.beta.messages.count_tokens(
|
||||
betas=["token-counting-2024-11-01"],
|
||||
model=self.model,
|
||||
system=formatted_system,
|
||||
messages=formatted_messages, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
response = self._client.beta.messages.count_tokens(
|
||||
betas=["token-counting-2024-11-01"],
|
||||
model=self.model,
|
||||
messages=formatted_messages, # type: ignore[arg-type]
|
||||
)
|
||||
kwargs["system"] = formatted_system
|
||||
if tools:
|
||||
kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]
|
||||
|
||||
response = self._client.beta.messages.count_tokens(
|
||||
betas=["token-counting-2024-11-01"],
|
||||
model=self.model,
|
||||
messages=formatted_messages, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
)
|
||||
return response.input_tokens
|
||||
|
||||
|
||||
|
@ -508,18 +508,34 @@ def test_with_structured_output() -> None:
|
||||
|
||||
|
||||
def test_get_num_tokens_from_messages() -> None:
|
||||
llm = ChatAnthropic(model="claude-3-5-haiku-20241022") # type: ignore[call-arg]
|
||||
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022") # type: ignore[call-arg]
|
||||
|
||||
# Test simple case
|
||||
messages = [
|
||||
SystemMessage(content="You are an assistant."),
|
||||
HumanMessage(content="What is the weather in SF?"),
|
||||
SystemMessage(content="You are a scientist"),
|
||||
HumanMessage(content="Hello, Claude"),
|
||||
]
|
||||
num_tokens = llm.get_num_tokens_from_messages(messages)
|
||||
assert num_tokens > 0
|
||||
|
||||
# Test tool use (not yet supported)
|
||||
# Test tool use
|
||||
@tool(parse_docstring=True)
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the current weather in a given location
|
||||
|
||||
Args:
|
||||
location: The city and state, e.g. San Francisco, CA
|
||||
"""
|
||||
return "Sunny"
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="What's the weather like in San Francisco?"),
|
||||
]
|
||||
num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
|
||||
assert num_tokens > 0
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="What's the weather like in San Francisco?"),
|
||||
AIMessage(
|
||||
content=[
|
||||
{"text": "Let's see.", "type": "text"},
|
||||
@ -538,10 +554,11 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
)
|
||||
),
|
||||
ToolMessage(content="Sunny", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"),
|
||||
]
|
||||
with pytest.raises(NotImplementedError):
|
||||
num_tokens = llm.get_num_tokens_from_messages(messages)
|
||||
num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
|
||||
assert num_tokens > 0
|
||||
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
|
Loading…
Reference in New Issue
Block a user