fix(standard-tests): ensure non-negative token counts in usage metadata assertions (#32593)

This commit is contained in:
Mason Daugherty
2025-09-08 16:49:26 -04:00
committed by GitHub
parent 8b90eae455
commit 35e9d36b0e

View File

@@ -1098,26 +1098,36 @@ class ChatModelIntegrationTests(ChatModelTests):
)
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_read_input()
assert (usage_metadata := msg.usage_metadata) is not None
assert (
input_token_details := usage_metadata.get("input_token_details")
) is not None
assert isinstance(input_token_details.get("cache_read"), int)
usage_metadata = msg.usage_metadata
assert usage_metadata is not None
input_token_details = usage_metadata.get("input_token_details")
assert input_token_details is not None
cache_read_tokens = input_token_details.get("cache_read")
assert isinstance(cache_read_tokens, int)
assert cache_read_tokens >= 0
# Asserts that total input tokens are at least the sum of the token counts
assert usage_metadata.get("input_tokens", 0) >= sum(
v for v in input_token_details.values() if isinstance(v, int)
total_detailed_tokens = sum(
v for v in input_token_details.values() if isinstance(v, int) and v >= 0
)
input_tokens = usage_metadata.get("input_tokens", 0)
assert isinstance(input_tokens, int)
assert input_tokens >= total_detailed_tokens
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_creation_input()
assert (usage_metadata := msg.usage_metadata) is not None
assert (
input_token_details := usage_metadata.get("input_token_details")
) is not None
assert isinstance(input_token_details.get("cache_creation"), int)
usage_metadata = msg.usage_metadata
assert usage_metadata is not None
input_token_details = usage_metadata.get("input_token_details")
assert input_token_details is not None
cache_creation_tokens = input_token_details.get("cache_creation")
assert isinstance(cache_creation_tokens, int)
assert cache_creation_tokens >= 0
# Asserts that total input tokens are at least the sum of the token counts
assert usage_metadata.get("input_tokens", 0) >= sum(
v for v in input_token_details.values() if isinstance(v, int)
total_detailed_tokens = sum(
v for v in input_token_details.values() if isinstance(v, int) and v >= 0
)
input_tokens = usage_metadata.get("input_tokens", 0)
assert isinstance(input_tokens, int)
assert input_tokens >= total_detailed_tokens
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
"""Test usage metadata in streaming mode.