feat(core): allow scaling by reported usage when counting tokens approximately (#34996)
@@ -2188,6 +2188,7 @@ def count_tokens_approximately(
     extra_tokens_per_message: float = 3.0,
     count_name: bool = True,
     tokens_per_image: int = 85,
+    use_usage_metadata_scaling: bool = False,
 ) -> int:
     """Approximate the total number of tokens in messages.
 
@@ -2211,6 +2212,11 @@ def count_tokens_approximately(
         count_name: Whether to include message names in the count.
         tokens_per_image: Fixed token cost per image (default: 85, aligned with
             OpenAI's low-resolution image token cost).
+        use_usage_metadata_scaling: If True, and all AI messages have consistent
+            `response_metadata['model_provider']`, scale the approximate token count
+            using the **most recent** AI message that has
+            `usage_metadata['total_tokens']`. The scaling factor is:
+            `AI_total_tokens / approx_tokens_up_to_that_AI_message`
 
     Returns:
         Approximate number of tokens in the messages.
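
For orientation, here is a minimal usage sketch of the new flag. It is not part of the commit: it assumes the usual langchain-core import paths (message classes from langchain_core.messages, the counter from langchain_core.messages.utils, where this diff edits it), and all token figures are invented.

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.utils import count_tokens_approximately

history = [
    HumanMessage("Summarize the report."),
    AIMessage(
        "Here is the summary...",
        response_metadata={"model_provider": "openai"},
        # Invented numbers: the provider reported 120 total tokens so far.
        usage_metadata={"input_tokens": 80, "output_tokens": 40, "total_tokens": 120},
    ),
    HumanMessage("Now shorten it to one paragraph."),
]

# Character-based heuristic only.
approx = count_tokens_approximately(history)

# Same heuristic, then rescaled against the provider-reported total_tokens on
# the most recent AIMessage (the factor is floored at 1.0, so the calibrated
# count never drops below the plain heuristic).
calibrated = count_tokens_approximately(history, use_usage_metadata_scaling=True)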
@@ -2225,8 +2231,16 @@ def count_tokens_approximately(
 
     !!! version-added "Added in `langchain-core` 0.3.46"
     """
+    converted_messages = convert_to_messages(messages)
+
     token_count = 0.0
-    for message in convert_to_messages(messages):
+
+    ai_model_provider: str | None = None
+    invalid_model_provider = False
+    last_ai_total_tokens: int | None = None
+    approx_at_last_ai: float | None = None
+
+    for message in converted_messages:
         message_chars = 0
 
         if isinstance(message.content, str):
@@ -2284,6 +2298,30 @@ def count_tokens_approximately(
         # add extra tokens per message
         token_count += extra_tokens_per_message
 
+        if use_usage_metadata_scaling and isinstance(message, AIMessage):
+            model_provider = message.response_metadata.get("model_provider")
+            if ai_model_provider is None:
+                ai_model_provider = model_provider
+            elif model_provider != ai_model_provider:
+                invalid_model_provider = True
+
+            if message.usage_metadata and isinstance(
+                (total_tokens := message.usage_metadata.get("total_tokens")), int
+            ):
+                last_ai_total_tokens = total_tokens
+                approx_at_last_ai = token_count
+
+    if (
+        use_usage_metadata_scaling
+        and not invalid_model_provider
+        and ai_model_provider is not None
+        and last_ai_total_tokens is not None
+        and approx_at_last_ai
+        and approx_at_last_ai > 0
+    ):
+        scale_factor = last_ai_total_tokens / approx_at_last_ai
+        token_count *= max(1.0, scale_factor)
+
     # round up one more time in case extra_tokens_per_message is a float
     return math.ceil(token_count)
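
The scaling step itself is plain proportional calibration: whatever ratio the provider-reported total bears to the heuristic count at the last AI message is applied to the whole running total. A self-contained sketch of the arithmetic, with illustrative values that are not taken from the commit:

import math

# Invented example values.
approx_at_last_ai = 40.0    # heuristic count up through the last AIMessage
last_ai_total_tokens = 100  # provider-reported total_tokens on that message
final_heuristic = 55.0      # heuristic count over the whole message list

scale_factor = last_ai_total_tokens / approx_at_last_ai  # 100 / 40 = 2.5
# max(1.0, ...) floors the factor so a low reported total never deflates
# the estimate below the plain heuristic.
calibrated = final_heuristic * max(1.0, scale_factor)    # 55 * 2.5 = 137.5

print(math.ceil(calibrated))  # 138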
@@ -1,5 +1,6 @@
 import base64
 import json
+import math
 import re
 from collections.abc import Callable, Sequence
 from typing import Any, TypedDict
@@ -1594,6 +1595,103 @@ def test_count_tokens_approximately_mixed_content_types() -> None:
     assert sum(count_tokens_approximately([m]) for m in messages) == token_count
+
+
+def test_count_tokens_approximately_usage_metadata_scaling() -> None:
+    messages = [
+        HumanMessage("text"),
+        AIMessage(
+            "text",
+            response_metadata={"model_provider": "openai"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
+        ),
+        HumanMessage("text"),
+        AIMessage(
+            "text",
+            response_metadata={"model_provider": "openai"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
+        ),
+    ]
+
+    unscaled = count_tokens_approximately(messages)
+    scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
+
+    assert scaled == 200
+    assert unscaled < 100
+
+    messages.extend([ToolMessage("text", tool_call_id="abc123")] * 3)
+
+    unscaled_extended = count_tokens_approximately(messages)
+    scaled_extended = count_tokens_approximately(
+        messages, use_usage_metadata_scaling=True
+    )
+
+    # scaling should still be based on the most recent AIMessage with total_tokens=200
+    assert unscaled_extended > unscaled
+    assert scaled_extended > scaled
+
+    # And the scaled total should be the unscaled total multiplied by the same ratio.
+    # ratio = 200 / unscaled (as of last AI message)
+    expected_scaled_extended = math.ceil(unscaled_extended * (200 / unscaled))
+    assert scaled_extended == expected_scaled_extended
+
+
+def test_count_tokens_approximately_usage_metadata_scaling_model_provider() -> None:
+    messages = [
+        HumanMessage("Hello"),
+        AIMessage(
+            "Hi",
+            response_metadata={"model_provider": "openai"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100},
+        ),
+        HumanMessage("More text"),
+        AIMessage(
+            "More response",
+            response_metadata={"model_provider": "anthropic"},
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200},
+        ),
+    ]
+
+    unscaled = count_tokens_approximately(messages)
+    scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
+    assert scaled == unscaled
+
+
+def test_count_tokens_approximately_usage_metadata_scaling_total_tokens() -> None:
+    messages = [
+        HumanMessage("Hello"),
+        AIMessage(
+            "Hi",
+            response_metadata={"model_provider": "openai"},
+            # no usage metadata -> skip
+        ),
+    ]
+
+    unscaled = count_tokens_approximately(messages, chars_per_token=5)
+    scaled = count_tokens_approximately(
+        messages, chars_per_token=5, use_usage_metadata_scaling=True
+    )
+
+    assert scaled == unscaled
+
+
+def test_count_tokens_approximately_usage_metadata_scaling_floor_at_one() -> None:
+    messages = [
+        HumanMessage("text"),
+        AIMessage(
+            "text",
+            response_metadata={"model_provider": "openai"},
+            # Set total_tokens lower than the approximate count up through this message.
+            usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 1},
+        ),
+        HumanMessage("text"),
+    ]
+
+    unscaled = count_tokens_approximately(messages)
+    scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True)
+
+    # scale factor would be < 1, but we floor it at 1.0 to avoid decreasing counts
+    assert scaled == unscaled
 
 
 def test_get_buffer_string_with_structured_content() -> None:
     """Test get_buffer_string with structured content in messages."""
     messages = [