diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 57b8e01792b..7c05c0198df 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -2181,18 +2181,29 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
     # Currently just copying over the 5m and 1h keys, but if more are added in the
     # future we'll need to expand this tuple
     cache_creation_keys = ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens")
+    specific_cache_creation_tokens = 0
     if cache_creation:
         if isinstance(cache_creation, BaseModel):
             cache_creation = cache_creation.model_dump()
         for k in cache_creation_keys:
-            input_token_details[k] = cache_creation.get(k)
+            v = cache_creation.get(k)
+            input_token_details[k] = v
+            # Validate before accumulating: a key present with a None (or other
+            # non-int) value must not crash the sum with a TypeError.
+            if isinstance(v, int):
+                specific_cache_creation_tokens += v
+        if specific_cache_creation_tokens > 0:
+            # Remove generic key to avoid double counting cache creation tokens
+            input_token_details["cache_creation"] = 0
 
     # Calculate total input tokens: Anthropic's `input_tokens` excludes cached tokens,
     # so we need to add them back to get the true total input token count
     input_tokens = (
         (getattr(anthropic_usage, "input_tokens", 0) or 0)  # Base input tokens
         + (input_token_details["cache_read"] or 0)  # Tokens read from cache
-        + (input_token_details["cache_creation"] or 0)  # Tokens used to create cache
+        + (
+            specific_cache_creation_tokens or input_token_details["cache_creation"] or 0
+        )  # Tokens used to create cache
     )
 
     output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0
diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
index 02862bdef74..00ef3823627 100644
--- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -1519,6 +1519,132 @@ def test_usage_metadata_standardization() -> None:
     assert result["total_tokens"] == 0
 
 
+def test_usage_metadata_cache_creation_ttl() -> None:
+    """Test _create_usage_metadata with granular cache_creation TTL fields."""
+
+    # Case 1: cache_creation with specific ephemeral TTL tokens (BaseModel)
+    class CacheCreation(BaseModel):
+        ephemeral_5m_input_tokens: int = 100
+        ephemeral_1h_input_tokens: int = 50
+
+    class UsageWithCacheCreation(BaseModel):
+        input_tokens: int = 200
+        output_tokens: int = 30
+        cache_read_input_tokens: int = 10
+        cache_creation_input_tokens: int = 150
+        cache_creation: CacheCreation = CacheCreation()
+
+    result = _create_usage_metadata(UsageWithCacheCreation())
+    # input_tokens = 200 (base) + 10 (cache_read) + 150 (specific: 100+50)
+    assert result["input_tokens"] == 360
+    assert result["output_tokens"] == 30
+    assert result["total_tokens"] == 390
+    details = dict(result.get("input_token_details") or {})
+    assert details["cache_read"] == 10
+    # cache_creation should be suppressed to avoid double counting
+    assert details["cache_creation"] == 0
+    assert details["ephemeral_5m_input_tokens"] == 100
+    assert details["ephemeral_1h_input_tokens"] == 50
+
+    # Case 2: cache_creation as a dict
+    class UsageWithCacheCreationDict(BaseModel):
+        input_tokens: int = 200
+        output_tokens: int = 30
+        cache_read_input_tokens: int = 10
+        cache_creation_input_tokens: int = 150
+        cache_creation: dict = {
+            "ephemeral_5m_input_tokens": 80,
+            "ephemeral_1h_input_tokens": 70,
+        }
+
+    result = _create_usage_metadata(UsageWithCacheCreationDict())
+    assert result["input_tokens"] == 200 + 10 + 80 + 70
+    details = dict(result.get("input_token_details") or {})
+    assert details["cache_creation"] == 0
+    assert details["ephemeral_5m_input_tokens"] == 80
+    assert details["ephemeral_1h_input_tokens"] == 70
+
+    # Case 3: cache_creation exists but specific keys are zero — falls back to
+    # generic cache_creation_input_tokens
+    class CacheCreationZero(BaseModel):
+        ephemeral_5m_input_tokens: int = 0
+        ephemeral_1h_input_tokens: int = 0
+
+    class UsageWithCacheCreationZero(BaseModel):
+        input_tokens: int = 200
+        output_tokens: int = 30
+        cache_read_input_tokens: int = 10
+        cache_creation_input_tokens: int = 50
+        cache_creation: CacheCreationZero = CacheCreationZero()
+
+    result = _create_usage_metadata(UsageWithCacheCreationZero())
+    # specific_cache_creation_tokens = 0, so falls back to cache_creation_input_tokens
+    # input_tokens = 200 + 10 + 50 = 260
+    assert result["input_tokens"] == 260
+    assert result["output_tokens"] == 30
+    assert result["total_tokens"] == 290
+    details = dict(result.get("input_token_details") or {})
+    assert details["cache_read"] == 10
+    assert details["cache_creation"] == 50
+
+    # Case 4: cache_creation exists but specific keys are missing from the dict
+    class CacheCreationEmpty(BaseModel):
+        pass
+
+    class UsageWithCacheCreationEmpty(BaseModel):
+        input_tokens: int = 100
+        output_tokens: int = 20
+        cache_read_input_tokens: int = 5
+        cache_creation_input_tokens: int = 15
+        cache_creation: CacheCreationEmpty = CacheCreationEmpty()
+
+    result = _create_usage_metadata(UsageWithCacheCreationEmpty())
+    # specific_cache_creation_tokens = 0, falls back to cache_creation_input_tokens
+    assert result["input_tokens"] == 100 + 5 + 15
+    assert result["output_tokens"] == 20
+    assert result["total_tokens"] == 140
+    details = dict(result.get("input_token_details") or {})
+    assert details["cache_creation"] == 15
+
+    # Case 5: only one ephemeral key is non-zero
+    class CacheCreationPartial(BaseModel):
+        ephemeral_5m_input_tokens: int = 0
+        ephemeral_1h_input_tokens: int = 75
+
+    class UsageWithPartialCache(BaseModel):
+        input_tokens: int = 100
+        output_tokens: int = 10
+        cache_read_input_tokens: int = 0
+        cache_creation_input_tokens: int = 75
+        cache_creation: CacheCreationPartial = CacheCreationPartial()
+
+    result = _create_usage_metadata(UsageWithPartialCache())
+    # specific_cache_creation_tokens = 75 > 0, so generic cache_creation is suppressed
+    assert result["input_tokens"] == 100 + 0 + 75
+    assert result["output_tokens"] == 10
+    assert result["total_tokens"] == 185
+    details = dict(result.get("input_token_details") or {})
+    assert details["cache_creation"] == 0
+    assert details["ephemeral_1h_input_tokens"] == 75
+    # ephemeral_5m_input_tokens is 0 — still included since 0 is not None
+    assert details["ephemeral_5m_input_tokens"] == 0
+
+    # Case 6: no cache_creation field at all (the pre-existing path)
+    class UsageNoCacheCreation(BaseModel):
+        input_tokens: int = 50
+        output_tokens: int = 25
+        cache_read_input_tokens: int = 5
+        cache_creation_input_tokens: int = 10
+
+    result = _create_usage_metadata(UsageNoCacheCreation())
+    assert result["input_tokens"] == 50 + 5 + 10
+    assert result["output_tokens"] == 25
+    assert result["total_tokens"] == 90
+    details = dict(result.get("input_token_details") or {})
+    assert details["cache_read"] == 5
+    assert details["cache_creation"] == 10
+
+
 class FakeTracer(BaseTracer):
     """Fake tracer to capture inputs to `chat_model_start`."""
 