feat(core): allow scaling by reported usage when counting tokens approximately (#34996)

Author: ccurme
Date: 2026-02-03 15:19:18 -05:00
Committed by: GitHub
Parent: 8072a51f39
Commit: 09654f4382
2 changed files with 137 additions and 1 deletion


@@ -2188,6 +2188,7 @@ def count_tokens_approximately(
     extra_tokens_per_message: float = 3.0,
     count_name: bool = True,
     tokens_per_image: int = 85,
+    use_usage_metadata_scaling: bool = False,
 ) -> int:
     """Approximate the total number of tokens in messages.
@@ -2211,6 +2212,11 @@ def count_tokens_approximately(
         count_name: Whether to include message names in the count.
         tokens_per_image: Fixed token cost per image (default: 85, aligned with
             OpenAI's low-resolution image token cost).
+        use_usage_metadata_scaling: If `True` and all AI messages have a consistent
+            `response_metadata['model_provider']`, scale the approximate token count
+            using the **most recent** AI message that has
+            `usage_metadata['total_tokens']`. The scaling factor is
+            `AI_total_tokens / approx_tokens_up_to_that_AI_message`.

     Returns:
         Approximate number of tokens in the messages.
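
For intuition, a rough worked example of that scaling factor (the numbers and variable names here are illustrative, not taken from the diff):

    # Suppose the heuristic counts 16 tokens up through the most recent AI
    # message, but that message reports usage_metadata["total_tokens"] == 200.
    approx_at_last_ai = 16.0
    reported_total = 200
    scale_factor = reported_total / approx_at_last_ai  # 12.5
    # The running approximate total is multiplied by this factor, floored at
    # 1.0 so that scaling can only increase the estimate, never shrink it:
    scaled_count = 16.0 * max(1.0, scale_factor)  # 200.0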
@@ -2225,8 +2231,16 @@ def count_tokens_approximately(
     !!! version-added "Added in `langchain-core` 0.3.46"
     """
+    converted_messages = convert_to_messages(messages)
     token_count = 0.0
-    for message in convert_to_messages(messages):
+    ai_model_provider: str | None = None
+    invalid_model_provider = False
+    last_ai_total_tokens: int | None = None
+    approx_at_last_ai: float | None = None
+    for message in converted_messages:
         message_chars = 0
         if isinstance(message.content, str):
@@ -2284,6 +2298,30 @@ def count_tokens_approximately(
         # add extra tokens per message
         token_count += extra_tokens_per_message

+        if use_usage_metadata_scaling and isinstance(message, AIMessage):
+            model_provider = message.response_metadata.get("model_provider")
+            if ai_model_provider is None:
+                ai_model_provider = model_provider
+            elif model_provider != ai_model_provider:
+                invalid_model_provider = True
+            if message.usage_metadata and isinstance(
+                (total_tokens := message.usage_metadata.get("total_tokens")), int
+            ):
+                last_ai_total_tokens = total_tokens
+                approx_at_last_ai = token_count
+
+    if (
+        use_usage_metadata_scaling
+        and not invalid_model_provider
+        and ai_model_provider is not None
+        and last_ai_total_tokens is not None
+        and approx_at_last_ai is not None
+        and approx_at_last_ai > 0
+    ):
+        scale_factor = last_ai_total_tokens / approx_at_last_ai
+        token_count *= max(1.0, scale_factor)
+
     # round up once more in case extra_tokens_per_message is a float
     return math.ceil(token_count)
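
A minimal usage sketch of the new flag (our own example, not part of the commit; assumes a `langchain-core` build that includes this change):

    from langchain_core.messages import AIMessage, HumanMessage
    from langchain_core.messages.utils import count_tokens_approximately

    history = [
        HumanMessage("Summarize our discussion so far."),
        AIMessage(
            "Here is a summary of the discussion...",
            response_metadata={"model_provider": "openai"},
            usage_metadata={
                "input_tokens": 40, "output_tokens": 60, "total_tokens": 100
            },
        ),
        HumanMessage("Thanks. Now expand on point two."),
    ]

    # Pure heuristic: roughly len(text) / 4 plus per-message overhead.
    print(count_tokens_approximately(history))

    # Same heuristic, rescaled so that the count up through the last AI
    # message matches that message's reported total_tokens.
    print(count_tokens_approximately(history, use_usage_metadata_scaling=True))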


@@ -1,5 +1,6 @@
 import base64
 import json
+import math
 import re
 from collections.abc import Callable, Sequence
 from typing import Any, TypedDict
@@ -1594,6 +1595,103 @@ def test_count_tokens_approximately_mixed_content_types() -> None:
     assert sum(count_tokens_approximately([m]) for m in messages) == token_count


+def test_count_tokens_approximately_usage_metadata_scaling() -> None:
+    messages = [
+        HumanMessage("text"),
+        AIMessage(
+            "text",
+            response_metadata={"model_provider": "openai"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
+        ),
+        HumanMessage("text"),
+        AIMessage(
+            "text",
+            response_metadata={"model_provider": "openai"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
+        ),
+    ]
+    unscaled = count_tokens_approximately(messages)
+    scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
+    assert scaled == 200
+    assert unscaled < 100
+
+    messages.extend([ToolMessage("text", tool_call_id="abc123")] * 3)
+    unscaled_extended = count_tokens_approximately(messages)
+    scaled_extended = count_tokens_approximately(
+        messages, use_usage_metadata_scaling=True
+    )
+    # Scaling should still be based on the most recent AIMessage with
+    # total_tokens=200.
+    assert unscaled_extended > unscaled
+    assert scaled_extended > scaled
+    # And the scaled total should be the unscaled total multiplied by the same
+    # ratio: ratio = 200 / unscaled (as of the last AI message).
+    expected_scaled_extended = math.ceil(unscaled_extended * (200 / unscaled))
+    assert scaled_extended == expected_scaled_extended
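
As a sanity check on the asserted numbers (our arithmetic, assuming the defaults `chars_per_token=4.0` and `extra_tokens_per_message=3.0`):

    # Each "text" message: 4 chars / 4.0 + 3.0 overhead = 4.0 tokens, so the
    # four messages give unscaled == 16 (hence the `unscaled < 100` assert).
    # The last AI message reports total_tokens=200, so the scale factor is
    # 200 / 16 = 12.5, and the scaled count is 16 * 12.5 == 200.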
+
+
+def test_count_tokens_approximately_usage_metadata_scaling_model_provider() -> None:
+    messages = [
+        HumanMessage("Hello"),
+        AIMessage(
+            "Hi",
+            response_metadata={"model_provider": "openai"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
+        ),
+        HumanMessage("More text"),
+        AIMessage(
+            "More response",
+            response_metadata={"model_provider": "anthropic"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
+        ),
+    ]
+    unscaled = count_tokens_approximately(messages)
+    scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
+    # Mixed model providers -> scaling is skipped.
+    assert scaled == unscaled
+
+
+def test_count_tokens_approximately_usage_metadata_scaling_total_tokens() -> None:
+    messages = [
+        HumanMessage("Hello"),
+        AIMessage(
+            "Hi",
+            response_metadata={"model_provider": "openai"},
+            # no usage metadata -> skip
+        ),
+    ]
+    unscaled = count_tokens_approximately(messages, chars_per_token=5)
+    scaled = count_tokens_approximately(
+        messages, chars_per_token=5, use_usage_metadata_scaling=True
+    )
+    assert scaled == unscaled
+
+
+def test_count_tokens_approximately_usage_metadata_scaling_floor_at_one() -> None:
+    messages = [
+        HumanMessage("text"),
+        AIMessage(
+            "text",
+            response_metadata={"model_provider": "openai"},
+            # Set total_tokens lower than the approximate count up through
+            # this message.
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 1},
+        ),
+        HumanMessage("text"),
+    ]
+    unscaled = count_tokens_approximately(messages)
+    scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
+    # The scale factor would be < 1, but it is floored at 1.0 to avoid
+    # decreasing counts.
+    assert scaled == unscaled
+
+
 def test_get_buffer_string_with_structured_content() -> None:
     """Test get_buffer_string with structured content in messages."""
     messages = [