diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index baa55c24332..9ac20f0df22 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -2188,6 +2188,7 @@ def count_tokens_approximately( extra_tokens_per_message: float = 3.0, count_name: bool = True, tokens_per_image: int = 85, + use_usage_metadata_scaling: bool = False, ) -> int: """Approximate the total number of tokens in messages. @@ -2211,6 +2212,11 @@ def count_tokens_approximately( count_name: Whether to include message names in the count. tokens_per_image: Fixed token cost per image (default: 85, aligned with OpenAI's low-resolution image token cost). + use_usage_metadata_scaling: If True, and all AI messages share the same non-None + `response_metadata['model_provider']`, scale the approximate token count + using the **most recent** AI message that has + `usage_metadata['total_tokens']`. The scaling factor is: + `AI_total_tokens / approx_tokens_up_to_that_AI_message`, floored at 1.0 so scaling never decreases the count. Returns: Approximate number of tokens in the messages. @@ -2225,8 +2231,16 @@ def count_tokens_approximately( !!! 
version-added "Added in `langchain-core` 0.3.46" """ + converted_messages = convert_to_messages(messages) + token_count = 0.0 - for message in convert_to_messages(messages): + + ai_model_provider: str | None = None + invalid_model_provider = False + last_ai_total_tokens: int | None = None + approx_at_last_ai: float | None = None + + for message in converted_messages: message_chars = 0 if isinstance(message.content, str): @@ -2284,6 +2298,30 @@ def count_tokens_approximately( # add extra tokens per message token_count += extra_tokens_per_message + if use_usage_metadata_scaling and isinstance(message, AIMessage): + model_provider = message.response_metadata.get("model_provider") + if ai_model_provider is None: + ai_model_provider = model_provider + elif model_provider != ai_model_provider: + invalid_model_provider = True + + if message.usage_metadata and isinstance( + (total_tokens := message.usage_metadata.get("total_tokens")), int + ): + last_ai_total_tokens = total_tokens + approx_at_last_ai = token_count + + if ( + use_usage_metadata_scaling + and not invalid_model_provider + and ai_model_provider is not None + and last_ai_total_tokens is not None + and approx_at_last_ai + and approx_at_last_ai > 0 + ): + scale_factor = last_ai_total_tokens / approx_at_last_ai + token_count *= max(1.0, scale_factor) + # round up once more time in case extra_tokens_per_message is a float return math.ceil(token_count) diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index faa2cf2f943..91060bb8dc9 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1,5 +1,6 @@ import base64 import json +import math import re from collections.abc import Callable, Sequence from typing import Any, TypedDict @@ -1594,6 +1595,103 @@ def test_count_tokens_approximately_mixed_content_types() -> None: assert sum(count_tokens_approximately([m]) for m in messages) == 
token_count +def test_count_tokens_approximately_usage_metadata_scaling() -> None: + messages = [ + HumanMessage("text"), + AIMessage( + "text", + response_metadata={"model_provider": "openai"}, + usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100}, + ), + HumanMessage("text"), + AIMessage( + "text", + response_metadata={"model_provider": "openai"}, + usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200}, + ), + ] + + unscaled = count_tokens_approximately(messages) + scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True) + + assert scaled == 200 + assert unscaled < 100 + + messages.extend([ToolMessage("text", tool_call_id="abc123")] * 3) + + unscaled_extended = count_tokens_approximately(messages) + scaled_extended = count_tokens_approximately( + messages, use_usage_metadata_scaling=True + ) + + # scaling should still be based on the most recent AIMessage with total_tokens=200 + assert unscaled_extended > unscaled + assert scaled_extended > scaled + + # And the scaled total should be the unscaled total multiplied by the same ratio. 
+ # ratio = 200 / unscaled (as of last AI message) + expected_scaled_extended = math.ceil(unscaled_extended * (200 / unscaled)) + assert scaled_extended == expected_scaled_extended + + +def test_count_tokens_approximately_usage_metadata_scaling_model_provider() -> None: + messages = [ + HumanMessage("Hello"), + AIMessage( + "Hi", + response_metadata={"model_provider": "openai"}, + usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 100}, + ), + HumanMessage("More text"), + AIMessage( + "More response", + response_metadata={"model_provider": "anthropic"}, + usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 200}, + ), + ] + + unscaled = count_tokens_approximately(messages) + scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True) + assert scaled == unscaled + + +def test_count_tokens_approximately_usage_metadata_scaling_total_tokens() -> None: + messages = [ + HumanMessage("Hello"), + AIMessage( + "Hi", + response_metadata={"model_provider": "openai"}, + # no usage metadata -> skip + ), + ] + + unscaled = count_tokens_approximately(messages, chars_per_token=5) + scaled = count_tokens_approximately( + messages, chars_per_token=5, use_usage_metadata_scaling=True + ) + + assert scaled == unscaled + + +def test_count_tokens_approximately_usage_metadata_scaling_floor_at_one() -> None: + messages = [ + HumanMessage("text"), + AIMessage( + "text", + response_metadata={"model_provider": "openai"}, + # Set total_tokens lower than the approximate count up through this message. 
+ usage_metadata={"input_tokens": 0, "output_tokens": 0, "total_tokens": 1}, + ), + HumanMessage("text"), + ] + + unscaled = count_tokens_approximately(messages) + scaled = count_tokens_approximately(messages, use_usage_metadata_scaling=True) + + # scale factor would be < 1, but we floor it at 1.0 to avoid decreasing counts + assert scaled == unscaled + + def test_get_buffer_string_with_structured_content() -> None: """Test get_buffer_string with structured content in messages.""" messages = [