Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-21 14:43:07 +00:00)
feat(core): count tokens from tool schemas in count_tokens_approximately (#35098)
@@ -47,11 +47,13 @@ from langchain_core.messages.human import HumanMessage, HumanMessageChunk
 from langchain_core.messages.modifier import RemoveMessage
 from langchain_core.messages.system import SystemMessage, SystemMessageChunk
 from langchain_core.messages.tool import ToolCall, ToolMessage, ToolMessageChunk
+from langchain_core.utils.function_calling import convert_to_openai_tool
 
 if TYPE_CHECKING:
     from langchain_core.language_models import BaseLanguageModel
     from langchain_core.prompt_values import PromptValue
     from langchain_core.runnables.base import Runnable
+    from langchain_core.tools import BaseTool
 
 try:
     from langchain_text_splitters import TextSplitter
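For context on the first new import: `convert_to_openai_tool` normalizes a `BaseTool` (or a plain callable) into the OpenAI-style tool schema that the new counting code serializes. A minimal sketch of what it produces, using a hypothetical `multiply` function; the exact JSON Schema details may differ:

from langchain_core.utils.function_calling import convert_to_openai_tool

def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

schema = convert_to_openai_tool(multiply)
# Roughly:
# {"type": "function",
#  "function": {"name": "multiply",
#               "description": "Multiply two integers.",
#               "parameters": {"type": "object", "properties": {...}, "required": [...]}}}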
@@ -2189,6 +2191,7 @@ def count_tokens_approximately(
     count_name: bool = True,
     tokens_per_image: int = 85,
     use_usage_metadata_scaling: bool = False,
+    tools: list[BaseTool | dict[str, Any]] | None = None,
 ) -> int:
     """Approximate the total number of tokens in messages.
 
@@ -2198,6 +2201,7 @@ def count_tokens_approximately(
     - For tool messages, the token count also includes the tool call ID.
     - For multimodal messages with images, applies a fixed token penalty per image
       instead of counting base64-encoded characters.
+    - If tools are provided, the token count also includes stringified tool schemas.
 
     Args:
         messages: List of messages to count tokens for.
@@ -2217,9 +2221,12 @@ def count_tokens_approximately(
             using the **most recent** AI message that has
             `usage_metadata['total_tokens']`. The scaling factor is:
             `AI_total_tokens / approx_tokens_up_to_that_AI_message`
+        tools: List of tools to include in the token count. Each tool can be either
+            a `BaseTool` instance or a dict representing a tool schema. `BaseTool`
+            instances are converted to OpenAI tool format before counting.
 
     Returns:
-        Approximate number of tokens in the messages.
+        Approximate number of tokens in the messages (and tools, if provided).
 
     Note:
         This is a simple approximation that may not match the exact token count used by
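A minimal usage sketch of the new `tools` parameter, assuming a langchain-core build that includes this change; the absolute counts are approximations, only the relative increase is meaningful:

from langchain_core.messages import HumanMessage
from langchain_core.messages.utils import count_tokens_approximately

weather_schema = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the weather for a location.",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}

messages = [HumanMessage(content="What's the weather in Paris?")]
without_tools = count_tokens_approximately(messages)
with_tools = count_tokens_approximately(messages, tools=[weather_schema])
assert with_tools > without_tools  # the stringified schema adds to the estimate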
@@ -2240,6 +2247,14 @@ def count_tokens_approximately(
     last_ai_total_tokens: int | None = None
     approx_at_last_ai: float | None = None
 
+    # Count tokens for tools if provided
+    if tools:
+        tools_chars = 0
+        for tool in tools:
+            tool_dict = tool if isinstance(tool, dict) else convert_to_openai_tool(tool)
+            tools_chars += len(json.dumps(tool_dict))
+        token_count += math.ceil(tools_chars / chars_per_token)
+
     for message in converted_messages:
         message_chars = 0
 
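The added block approximates tool overhead purely from the length of the serialized schemas. A standalone sketch of the same arithmetic, assuming the function's default of roughly four characters per token (the `chars_per_token` parameter used above):

import json
import math

def approx_tool_tokens(tool_schemas: list[dict], chars_per_token: float = 4.0) -> int:
    """Mirror the added logic: serialize each schema and count characters."""
    total_chars = sum(len(json.dumps(schema)) for schema in tool_schemas)
    return math.ceil(total_chars / chars_per_token)

# e.g. a schema that serializes to 230 characters contributes ceil(230 / 4) = 58 tokens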
@@ -2313,6 +2328,7 @@ def count_tokens_approximately(
|
||||
|
||||
if (
|
||||
use_usage_metadata_scaling
|
||||
and len(converted_messages) > 1
|
||||
and not invalid_model_provider
|
||||
and ai_model_provider is not None
|
||||
and last_ai_total_tokens is not None
|
||||
|
||||
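For reference, the condition above gates the optional usage-metadata scaling described in the docstring; a rough sketch of that scaling with hypothetical numbers (the real implementation tracks these values while iterating over the messages):

import math

# Hypothetical values: the most recent AI message reported 1200 total tokens,
# while the character-based approximation up to that message came out to 1000.
last_ai_total_tokens = 1200
approx_at_last_ai = 1000.0
approx_total = 1500.0  # character-based approximation over all messages (and tool schemas)

scale = last_ai_total_tokens / approx_at_last_ai  # 1.2
print(math.ceil(approx_total * scale))  # 1800 (rounding is illustrative)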
@@ -29,7 +29,7 @@ from langchain_core.messages.utils import (
     merge_message_runs,
     trim_messages,
 )
-from langchain_core.tools import BaseTool
+from langchain_core.tools import BaseTool, tool
 
 
 @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -2908,3 +2908,53 @@ def test_count_tokens_approximately_respects_count_name_flag() -> None:
 
     # When count_name is True, the name should contribute to the token count.
     assert with_name > without_name
+
+
+def test_count_tokens_approximately_with_tools() -> None:
+    """Test that tools parameter adds to token count."""
+    messages = [HumanMessage(content="Hello")]
+    base_count = count_tokens_approximately(messages)
+
+    # Test with a BaseTool instance
+    @tool
+    def get_weather(location: str) -> str:
+        """Get the weather for a location."""
+        return f"Weather in {location}"
+
+    count_with_tool = count_tokens_approximately(messages, tools=[get_weather])
+    assert count_with_tool > base_count
+
+    # Test with a dict tool schema
+    tool_schema = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get the weather for a location.",
+            "parameters": {
+                "type": "object",
+                "properties": {"location": {"type": "string"}},
+                "required": ["location"],
+            },
+        },
+    }
+    count_with_dict_tool = count_tokens_approximately(messages, tools=[tool_schema])
+    assert count_with_dict_tool > base_count
+
+    # Test with multiple tools
+    @tool
+    def get_time(timezone: str) -> str:
+        """Get the current time in a timezone."""
+        return f"Time in {timezone}"
+
+    count_with_multiple = count_tokens_approximately(
+        messages, tools=[get_weather, get_time]
+    )
+    assert count_with_multiple > count_with_tool
+
+    # Test with no tools (None) should equal base count
+    count_no_tools = count_tokens_approximately(messages, tools=None)
+    assert count_no_tools == base_count
+
+    # Test with empty tools list should equal base count
+    count_empty_tools = count_tokens_approximately(messages, tools=[])
+    assert count_empty_tools == base_count
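Beyond the unit tests, one likely use of the new parameter is as a drop-in `token_counter` for utilities such as `trim_messages` (imported in the test module above), bound via `functools.partial`. A sketch, with the tool and the message values being hypothetical:

from functools import partial

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
from langchain_core.tools import tool

@tool
def get_weather(location: str) -> str:
    """Get the weather for a location."""
    return f"Weather in {location}"

# Binding tools keeps the (messages) -> int shape that trim_messages expects,
# so the tool-schema overhead is included in the token budget.
token_counter = partial(count_tokens_approximately, tools=[get_weather])

trimmed = trim_messages(
    [HumanMessage("Hi"), AIMessage("Hello!"), HumanMessage("What's the weather?")],
    max_tokens=80,
    token_counter=token_counter,
    strategy="last",
)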