fix(standard-tests): ensure non-negative token counts in usage metadata assertions (#32593)

This commit is contained in:
Mason Daugherty
2025-09-08 16:49:26 -04:00
committed by GitHub
parent 8b90eae455
commit 35e9d36b0e

View File

@@ -1098,26 +1098,36 @@ class ChatModelIntegrationTests(ChatModelTests):
) )
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]: if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_read_input() msg = self.invoke_with_cache_read_input()
assert (usage_metadata := msg.usage_metadata) is not None usage_metadata = msg.usage_metadata
assert ( assert usage_metadata is not None
input_token_details := usage_metadata.get("input_token_details") input_token_details = usage_metadata.get("input_token_details")
) is not None assert input_token_details is not None
assert isinstance(input_token_details.get("cache_read"), int) cache_read_tokens = input_token_details.get("cache_read")
assert isinstance(cache_read_tokens, int)
assert cache_read_tokens >= 0
# Asserts that total input tokens are at least the sum of the token counts # Asserts that total input tokens are at least the sum of the token counts
assert usage_metadata.get("input_tokens", 0) >= sum( total_detailed_tokens = sum(
v for v in input_token_details.values() if isinstance(v, int) v for v in input_token_details.values() if isinstance(v, int) and v >= 0
) )
input_tokens = usage_metadata.get("input_tokens", 0)
assert isinstance(input_tokens, int)
assert input_tokens >= total_detailed_tokens
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]: if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
msg = self.invoke_with_cache_creation_input() msg = self.invoke_with_cache_creation_input()
assert (usage_metadata := msg.usage_metadata) is not None usage_metadata = msg.usage_metadata
assert ( assert usage_metadata is not None
input_token_details := usage_metadata.get("input_token_details") input_token_details = usage_metadata.get("input_token_details")
) is not None assert input_token_details is not None
assert isinstance(input_token_details.get("cache_creation"), int) cache_creation_tokens = input_token_details.get("cache_creation")
assert isinstance(cache_creation_tokens, int)
assert cache_creation_tokens >= 0
# Asserts that total input tokens are at least the sum of the token counts # Asserts that total input tokens are at least the sum of the token counts
assert usage_metadata.get("input_tokens", 0) >= sum( total_detailed_tokens = sum(
v for v in input_token_details.values() if isinstance(v, int) v for v in input_token_details.values() if isinstance(v, int) and v >= 0
) )
input_tokens = usage_metadata.get("input_tokens", 0)
assert isinstance(input_tokens, int)
assert input_tokens >= total_detailed_tokens
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
"""Test usage metadata in streaming mode. """Test usage metadata in streaming mode.