diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index d1f0237b9b9..122f523e9e3 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1844,11 +1844,11 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata: # Anthropic input_tokens exclude cached token counts. input_tokens = ( - getattr(anthropic_usage, "input_tokens", 0) + (getattr(anthropic_usage, "input_tokens", 0) or 0) + (input_token_details["cache_read"] or 0) + (input_token_details["cache_creation"] or 0) ) - output_tokens = getattr(anthropic_usage, "output_tokens", 0) + output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0 return UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 4d50794b00b..1c18fc08d94 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -1,7 +1,7 @@ """Test chat model integration.""" import os -from typing import Any, Callable, Literal, cast +from typing import Any, Callable, Literal, Optional, cast from unittest.mock import patch import anthropic @@ -15,6 +15,7 @@ from pytest import CaptureFixture, MonkeyPatch from langchain_anthropic import ChatAnthropic from langchain_anthropic.chat_models import ( + _create_usage_metadata, _format_image, _format_messages, _merge_messages, @@ -954,3 +955,42 @@ def test_get_num_tokens_from_messages_passes_kwargs() -> None: assert ( _Client.return_value.beta.messages.count_tokens.call_args.kwargs["foo"] == "bar" ) + + +def test_usage_metadata_standardization() -> None: + class UsageModel(BaseModel): + input_tokens: int = 10 + output_tokens: int = 5 + cache_read_input_tokens: int = 3 + cache_creation_input_tokens: int = 2 + + # Happy path + usage = UsageModel() + result = _create_usage_metadata(usage) + assert result["input_tokens"] == 15 # 10 + 3 + 2 + assert result["output_tokens"] == 5 + assert result["total_tokens"] == 20 + assert result["input_token_details"] == {"cache_read": 3, "cache_creation": 2} + + # Null input and output tokens + class UsageModelNulls(BaseModel): + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + cache_read_input_tokens: Optional[int] = None + cache_creation_input_tokens: Optional[int] = None + + usage_nulls = UsageModelNulls() + result = _create_usage_metadata(usage_nulls) + assert result["input_tokens"] == 0 + assert result["output_tokens"] == 0 + assert result["total_tokens"] == 0 + + # Test missing fields + class UsageModelMissing(BaseModel): + pass + + usage_missing = UsageModelMissing() + result = _create_usage_metadata(usage_missing) + assert result["input_tokens"] == 0 + assert result["output_tokens"] == 0 + assert result["total_tokens"] == 0