anthropic[patch]: be robust to null fields when translating usage metadata (#31151)

This commit is contained in:
ccurme 2025-05-07 14:30:21 -04:00 committed by GitHub
parent f70b263ff3
commit b5b90b5929
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 3 deletions

View File

@ -1844,11 +1844,11 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
# Anthropic input_tokens exclude cached token counts. # Anthropic input_tokens exclude cached token counts.
input_tokens = ( input_tokens = (
getattr(anthropic_usage, "input_tokens", 0) (getattr(anthropic_usage, "input_tokens", 0) or 0)
+ (input_token_details["cache_read"] or 0) + (input_token_details["cache_read"] or 0)
+ (input_token_details["cache_creation"] or 0) + (input_token_details["cache_creation"] or 0)
) )
output_tokens = getattr(anthropic_usage, "output_tokens", 0) output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0
return UsageMetadata( return UsageMetadata(
input_tokens=input_tokens, input_tokens=input_tokens,
output_tokens=output_tokens, output_tokens=output_tokens,

View File

@ -1,7 +1,7 @@
"""Test chat model integration.""" """Test chat model integration."""
import os import os
from typing import Any, Callable, Literal, cast from typing import Any, Callable, Literal, Optional, cast
from unittest.mock import patch from unittest.mock import patch
import anthropic import anthropic
@ -15,6 +15,7 @@ from pytest import CaptureFixture, MonkeyPatch
from langchain_anthropic import ChatAnthropic from langchain_anthropic import ChatAnthropic
from langchain_anthropic.chat_models import ( from langchain_anthropic.chat_models import (
_create_usage_metadata,
_format_image, _format_image,
_format_messages, _format_messages,
_merge_messages, _merge_messages,
@ -954,3 +955,42 @@ def test_get_num_tokens_from_messages_passes_kwargs() -> None:
assert ( assert (
_Client.return_value.beta.messages.count_tokens.call_args.kwargs["foo"] == "bar" _Client.return_value.beta.messages.count_tokens.call_args.kwargs["foo"] == "bar"
) )
def test_usage_metadata_standardization() -> None:
class UsageModel(BaseModel):
input_tokens: int = 10
output_tokens: int = 5
cache_read_input_tokens: int = 3
cache_creation_input_tokens: int = 2
# Happy path
usage = UsageModel()
result = _create_usage_metadata(usage)
assert result["input_tokens"] == 15 # 10 + 3 + 2
assert result["output_tokens"] == 5
assert result["total_tokens"] == 20
assert result["input_token_details"] == {"cache_read": 3, "cache_creation": 2}
# Null input and output tokens
class UsageModelNulls(BaseModel):
input_tokens: Optional[int] = None
output_tokens: Optional[int] = None
cache_read_input_tokens: Optional[int] = None
cache_creation_input_tokens: Optional[int] = None
usage_nulls = UsageModelNulls()
result = _create_usage_metadata(usage_nulls)
assert result["input_tokens"] == 0
assert result["output_tokens"] == 0
assert result["total_tokens"] == 0
# Test missing fields
class UsageModelMissing(BaseModel):
pass
usage_missing = UsageModelMissing()
result = _create_usage_metadata(usage_missing)
assert result["input_tokens"] == 0
assert result["output_tokens"] == 0
assert result["total_tokens"] == 0