mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-11 16:01:33 +00:00
fix(standard-tests): ensure non-negative token counts in usage metadata assertions (#32593)
This commit is contained in:
@@ -1098,26 +1098,36 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
)
|
||||
if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
|
||||
msg = self.invoke_with_cache_read_input()
|
||||
assert (usage_metadata := msg.usage_metadata) is not None
|
||||
assert (
|
||||
input_token_details := usage_metadata.get("input_token_details")
|
||||
) is not None
|
||||
assert isinstance(input_token_details.get("cache_read"), int)
|
||||
usage_metadata = msg.usage_metadata
|
||||
assert usage_metadata is not None
|
||||
input_token_details = usage_metadata.get("input_token_details")
|
||||
assert input_token_details is not None
|
||||
cache_read_tokens = input_token_details.get("cache_read")
|
||||
assert isinstance(cache_read_tokens, int)
|
||||
assert cache_read_tokens >= 0
|
||||
# Asserts that total input tokens are at least the sum of the token counts
|
||||
assert usage_metadata.get("input_tokens", 0) >= sum(
|
||||
v for v in input_token_details.values() if isinstance(v, int)
|
||||
total_detailed_tokens = sum(
|
||||
v for v in input_token_details.values() if isinstance(v, int) and v >= 0
|
||||
)
|
||||
input_tokens = usage_metadata.get("input_tokens", 0)
|
||||
assert isinstance(input_tokens, int)
|
||||
assert input_tokens >= total_detailed_tokens
|
||||
if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
|
||||
msg = self.invoke_with_cache_creation_input()
|
||||
assert (usage_metadata := msg.usage_metadata) is not None
|
||||
assert (
|
||||
input_token_details := usage_metadata.get("input_token_details")
|
||||
) is not None
|
||||
assert isinstance(input_token_details.get("cache_creation"), int)
|
||||
usage_metadata = msg.usage_metadata
|
||||
assert usage_metadata is not None
|
||||
input_token_details = usage_metadata.get("input_token_details")
|
||||
assert input_token_details is not None
|
||||
cache_creation_tokens = input_token_details.get("cache_creation")
|
||||
assert isinstance(cache_creation_tokens, int)
|
||||
assert cache_creation_tokens >= 0
|
||||
# Asserts that total input tokens are at least the sum of the token counts
|
||||
assert usage_metadata.get("input_tokens", 0) >= sum(
|
||||
v for v in input_token_details.values() if isinstance(v, int)
|
||||
total_detailed_tokens = sum(
|
||||
v for v in input_token_details.values() if isinstance(v, int) and v >= 0
|
||||
)
|
||||
input_tokens = usage_metadata.get("input_tokens", 0)
|
||||
assert isinstance(input_tokens, int)
|
||||
assert input_tokens >= total_detailed_tokens
|
||||
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
"""Test usage metadata in streaming mode.
|
||||
|
Reference in New Issue
Block a user