diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py index b75d4d2e7f0..f3633bd58b0 100644 --- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py @@ -1098,26 +1098,36 @@ class ChatModelIntegrationTests(ChatModelTests): ) if "cache_read_input" in self.supported_usage_metadata_details["invoke"]: msg = self.invoke_with_cache_read_input() - assert (usage_metadata := msg.usage_metadata) is not None - assert ( - input_token_details := usage_metadata.get("input_token_details") - ) is not None - assert isinstance(input_token_details.get("cache_read"), int) + usage_metadata = msg.usage_metadata + assert usage_metadata is not None + input_token_details = usage_metadata.get("input_token_details") + assert input_token_details is not None + cache_read_tokens = input_token_details.get("cache_read") + assert isinstance(cache_read_tokens, int) + assert cache_read_tokens >= 0 # Asserts that total input tokens are at least the sum of the token counts - assert usage_metadata.get("input_tokens", 0) >= sum( - v for v in input_token_details.values() if isinstance(v, int) + total_detailed_tokens = sum( + v for v in input_token_details.values() if isinstance(v, int) and v >= 0 ) + input_tokens = usage_metadata.get("input_tokens", 0) + assert isinstance(input_tokens, int) + assert input_tokens >= total_detailed_tokens if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]: msg = self.invoke_with_cache_creation_input() - assert (usage_metadata := msg.usage_metadata) is not None - assert ( - input_token_details := usage_metadata.get("input_token_details") - ) is not None - assert isinstance(input_token_details.get("cache_creation"), int) + usage_metadata = msg.usage_metadata + assert usage_metadata is not None + input_token_details = usage_metadata.get("input_token_details") + assert input_token_details is not None + cache_creation_tokens = input_token_details.get("cache_creation") + assert isinstance(cache_creation_tokens, int) + assert cache_creation_tokens >= 0 # Asserts that total input tokens are at least the sum of the token counts - assert usage_metadata.get("input_tokens", 0) >= sum( - v for v in input_token_details.values() if isinstance(v, int) + total_detailed_tokens = sum( + v for v in input_token_details.values() if isinstance(v, int) and v >= 0 ) + input_tokens = usage_metadata.get("input_tokens", 0) + assert isinstance(input_tokens, int) + assert input_tokens >= total_detailed_tokens def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: """Test usage metadata in streaming mode.