fix(anthropic): Ignore general usage cache_creation fields if more specific fields are set (#35845)

Prevent double counting, since the sum of all `input_token_details`
should never exceed `input_tokens`

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Jacob Lee
2026-03-13 17:21:06 -07:00
committed by GitHub
parent b1f2d9c0fb
commit 6d6d7191cf
2 changed files with 136 additions and 1 deletion

View File

@@ -2181,18 +2181,27 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
# Currently just copying over the 5m and 1h keys, but if more are added in the
# future we'll need to expand this tuple
cache_creation_keys = ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens")
specific_cache_creation_tokens = 0
if cache_creation:
if isinstance(cache_creation, BaseModel):
cache_creation = cache_creation.model_dump()
for k in cache_creation_keys:
specific_cache_creation_tokens += cache_creation.get(k, 0)
input_token_details[k] = cache_creation.get(k)
if not isinstance(specific_cache_creation_tokens, int):
specific_cache_creation_tokens = 0
if specific_cache_creation_tokens > 0:
# Remove generic key to avoid double counting cache creation tokens
input_token_details["cache_creation"] = 0
# Calculate total input tokens: Anthropic's `input_tokens` excludes cached tokens,
# so we need to add them back to get the true total input token count
input_tokens = (
(getattr(anthropic_usage, "input_tokens", 0) or 0) # Base input tokens
+ (input_token_details["cache_read"] or 0) # Tokens read from cache
+ (input_token_details["cache_creation"] or 0) # Tokens used to create cache
+ (
specific_cache_creation_tokens or input_token_details["cache_creation"] or 0
) # Tokens used to create cache
)
output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0

View File

@@ -1519,6 +1519,132 @@ def test_usage_metadata_standardization() -> None:
assert result["total_tokens"] == 0
def test_usage_metadata_cache_creation_ttl() -> None:
    """Test _create_usage_metadata with granular cache_creation TTL fields."""

    # Case 1: cache_creation carries specific ephemeral TTL tokens (BaseModel).
    class GranularCache(BaseModel):
        ephemeral_5m_input_tokens: int = 100
        ephemeral_1h_input_tokens: int = 50

    class UsageGranular(BaseModel):
        input_tokens: int = 200
        output_tokens: int = 30
        cache_read_input_tokens: int = 10
        cache_creation_input_tokens: int = 150
        cache_creation: GranularCache = GranularCache()

    meta = _create_usage_metadata(UsageGranular())
    # input_tokens = 200 (base) + 10 (cache_read) + 150 (specific: 100 + 50)
    assert meta["input_tokens"] == 360
    assert meta["output_tokens"] == 30
    assert meta["total_tokens"] == 390
    detail = dict(meta.get("input_token_details") or {})
    assert detail["cache_read"] == 10
    # The generic cache_creation entry is zeroed to avoid double counting.
    assert detail["cache_creation"] == 0
    assert detail["ephemeral_5m_input_tokens"] == 100
    assert detail["ephemeral_1h_input_tokens"] == 50

    # Case 2: cache_creation supplied as a plain dict.
    class UsageGranularDict(BaseModel):
        input_tokens: int = 200
        output_tokens: int = 30
        cache_read_input_tokens: int = 10
        cache_creation_input_tokens: int = 150
        cache_creation: dict = {
            "ephemeral_5m_input_tokens": 80,
            "ephemeral_1h_input_tokens": 70,
        }

    meta = _create_usage_metadata(UsageGranularDict())
    assert meta["input_tokens"] == 200 + 10 + 80 + 70
    detail = dict(meta.get("input_token_details") or {})
    assert detail["cache_creation"] == 0
    assert detail["ephemeral_5m_input_tokens"] == 80
    assert detail["ephemeral_1h_input_tokens"] == 70

    # Case 3: cache_creation present but every specific key is zero — the
    # generic cache_creation_input_tokens value is used instead.
    class ZeroGranularCache(BaseModel):
        ephemeral_5m_input_tokens: int = 0
        ephemeral_1h_input_tokens: int = 0

    class UsageZeroGranular(BaseModel):
        input_tokens: int = 200
        output_tokens: int = 30
        cache_read_input_tokens: int = 10
        cache_creation_input_tokens: int = 50
        cache_creation: ZeroGranularCache = ZeroGranularCache()

    meta = _create_usage_metadata(UsageZeroGranular())
    # specific_cache_creation_tokens = 0, so the fallback applies:
    # input_tokens = 200 + 10 + 50 = 260
    assert meta["input_tokens"] == 260
    assert meta["output_tokens"] == 30
    assert meta["total_tokens"] == 290
    detail = dict(meta.get("input_token_details") or {})
    assert detail["cache_read"] == 10
    assert detail["cache_creation"] == 50

    # Case 4: cache_creation present but the specific keys are absent entirely.
    class EmptyGranularCache(BaseModel):
        pass

    class UsageEmptyGranular(BaseModel):
        input_tokens: int = 100
        output_tokens: int = 20
        cache_read_input_tokens: int = 5
        cache_creation_input_tokens: int = 15
        cache_creation: EmptyGranularCache = EmptyGranularCache()

    meta = _create_usage_metadata(UsageEmptyGranular())
    # No specific tokens, so the generic cache_creation_input_tokens applies.
    assert meta["input_tokens"] == 100 + 5 + 15
    assert meta["output_tokens"] == 20
    assert meta["total_tokens"] == 140
    detail = dict(meta.get("input_token_details") or {})
    assert detail["cache_creation"] == 15

    # Case 5: exactly one ephemeral key is non-zero.
    class PartialGranularCache(BaseModel):
        ephemeral_5m_input_tokens: int = 0
        ephemeral_1h_input_tokens: int = 75

    class UsagePartialGranular(BaseModel):
        input_tokens: int = 100
        output_tokens: int = 10
        cache_read_input_tokens: int = 0
        cache_creation_input_tokens: int = 75
        cache_creation: PartialGranularCache = PartialGranularCache()

    meta = _create_usage_metadata(UsagePartialGranular())
    # 75 specific tokens > 0, so the generic cache_creation count is suppressed.
    assert meta["input_tokens"] == 100 + 0 + 75
    assert meta["output_tokens"] == 10
    assert meta["total_tokens"] == 185
    detail = dict(meta.get("input_token_details") or {})
    assert detail["cache_creation"] == 0
    assert detail["ephemeral_1h_input_tokens"] == 75
    # A zero is still recorded for the 5m key, since 0 is not None.
    assert detail["ephemeral_5m_input_tokens"] == 0

    # Case 6: no cache_creation field at all (the pre-existing path).
    class UsagePlain(BaseModel):
        input_tokens: int = 50
        output_tokens: int = 25
        cache_read_input_tokens: int = 5
        cache_creation_input_tokens: int = 10

    meta = _create_usage_metadata(UsagePlain())
    assert meta["input_tokens"] == 50 + 5 + 10
    assert meta["output_tokens"] == 25
    assert meta["total_tokens"] == 90
    detail = dict(meta.get("input_token_details") or {})
    assert detail["cache_read"] == 5
    assert detail["cache_creation"] == 10
class FakeTracer(BaseTracer):
"""Fake tracer to capture inputs to `chat_model_start`."""