feat(core): count tokens from tool schemas in count_tokens_approximately (#35098)

This commit is contained in:
ccurme
2026-02-09 14:10:44 -05:00
committed by GitHub
parent c5aee74614
commit e8e47b083e
2 changed files with 68 additions and 2 deletions

View File

@@ -47,11 +47,13 @@ from langchain_core.messages.human import HumanMessage, HumanMessageChunk
from langchain_core.messages.modifier import RemoveMessage
from langchain_core.messages.system import SystemMessage, SystemMessageChunk
from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
from langchain_core.utils.function_calling import convert_to_openai_tool
if TYPE_CHECKING:
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables.base import Runnable
from langchain_core.tools import BaseTool
try:
from langchain_text_splitters import TextSplitter
@@ -2189,6 +2191,7 @@ def count_tokens_approximately(
count_name: bool = True,
tokens_per_image: int = 85,
use_usage_metadata_scaling: bool = False,
tools: list[BaseTool | dict[str, Any]] | None = None,
) -> int:
"""Approximate the total number of tokens in messages.
@@ -2198,6 +2201,7 @@ def count_tokens_approximately(
- For tool messages, the token count also includes the tool call ID.
- For multimodal messages with images, applies a fixed token penalty per image
instead of counting base64-encoded characters.
- If tools are provided, the token count also includes stringified tool schemas.
Args:
messages: List of messages to count tokens for.
@@ -2217,9 +2221,12 @@ def count_tokens_approximately(
using the **most recent** AI message that has
`usage_metadata['total_tokens']`. The scaling factor is:
`AI_total_tokens / approx_tokens_up_to_that_AI_message`
tools: List of tools to include in the token count. Each tool can be either
a `BaseTool` instance or a dict representing a tool schema. `BaseTool`
instances are converted to OpenAI tool format before counting.
Returns:
Approximate number of tokens in the messages.
Approximate number of tokens in the messages (and tools, if provided).
Note:
This is a simple approximation that may not match the exact token count used by
@@ -2240,6 +2247,14 @@ def count_tokens_approximately(
last_ai_total_tokens: int | None = None
approx_at_last_ai: float | None = None
# Count tokens for tools if provided
if tools:
tools_chars = 0
for tool in tools:
tool_dict = tool if isinstance(tool, dict) else convert_to_openai_tool(tool)
tools_chars += len(json.dumps(tool_dict))
token_count += math.ceil(tools_chars / chars_per_token)
for message in converted_messages:
message_chars = 0
@@ -2313,6 +2328,7 @@ def count_tokens_approximately(
if (
use_usage_metadata_scaling
and len(converted_messages) > 1
and not invalid_model_provider
and ai_model_provider is not None
and last_ai_total_tokens is not None

View File

@@ -29,7 +29,7 @@ from langchain_core.messages.utils import (
merge_message_runs,
trim_messages,
)
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, tool
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -2908,3 +2908,53 @@ def test_count_tokens_approximately_respects_count_name_flag() -> None:
# When count_name is True, the name should contribute to the token count.
assert with_name > without_name
def test_count_tokens_approximately_with_tools() -> None:
    """Verify that passing ``tools`` increases the approximate token count.

    Covers: a ``BaseTool`` instance, a raw dict schema, multiple tools,
    and the no-op cases ``tools=None`` / ``tools=[]``.
    """
    messages = [HumanMessage(content="Hello")]
    baseline = count_tokens_approximately(messages)

    # A BaseTool instance (created via the @tool decorator) must add tokens.
    @tool
    def get_weather(location: str) -> str:
        """Get the weather for a location."""
        return f"Weather in {location}"

    with_base_tool = count_tokens_approximately(messages, tools=[get_weather])
    assert with_base_tool > baseline

    # A plain dict in OpenAI tool-schema format must also add tokens.
    tool_schema = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the weather for a location.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
    with_dict_tool = count_tokens_approximately(messages, tools=[tool_schema])
    assert with_dict_tool > baseline

    # Two tools must cost more than one.
    @tool
    def get_time(timezone: str) -> str:
        """Get the current time in a timezone."""
        return f"Time in {timezone}"

    with_two_tools = count_tokens_approximately(
        messages, tools=[get_weather, get_time]
    )
    assert with_two_tools > with_base_tool

    # Omitting tools entirely (None) or passing an empty list is a no-op.
    assert count_tokens_approximately(messages, tools=None) == baseline
    assert count_tokens_approximately(messages, tools=[]) == baseline