anthropic[patch]: fix input_tokens when cached (#27125)

Author: Bagatur · 2024-10-04 15:35:51 -07:00 · committed by GitHub
parent 64a16f2cf0
commit 0b8416bd2e
3 changed files with 43 additions and 8 deletions


@@ -1253,7 +1253,12 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
         "cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
     }
-    input_tokens = getattr(anthropic_usage, "input_tokens", 0)
+    # Anthropic input_tokens exclude cached token counts.
+    input_tokens = (
+        getattr(anthropic_usage, "input_tokens", 0)
+        + (input_token_details["cache_read"] or 0)
+        + (input_token_details["cache_creation"] or 0)
+    )
     output_tokens = getattr(anthropic_usage, "output_tokens", 0)
     return UsageMetadata(
         input_tokens=input_tokens,
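
Taken on its own, the corrected accounting can be sketched like this (the SimpleNamespace stand-in for Anthropic's usage payload is hypothetical; the real object is a Pydantic model whose input_tokens field excludes cached token counts):

    from types import SimpleNamespace

    # Hypothetical stand-in for the Anthropic usage payload.
    usage = SimpleNamespace(
        input_tokens=2,
        output_tokens=1,
        cache_read_input_tokens=4,
        cache_creation_input_tokens=3,
    )

    input_token_details = {
        "cache_read": getattr(usage, "cache_read_input_tokens", None),
        "cache_creation": getattr(usage, "cache_creation_input_tokens", None),
    }
    # Fold the cached tokens back in so input_tokens reflects the full prompt.
    input_tokens = (
        getattr(usage, "input_tokens", 0)
        + (input_token_details["cache_read"] or 0)
        + (input_token_details["cache_creation"] or 0)
    )
    assert input_tokens == 9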


@@ -128,9 +128,9 @@ def test__format_output_cached() -> None:
     expected = AIMessage(  # type: ignore[misc]
         "bar",
         usage_metadata={
-            "input_tokens": 2,
+            "input_tokens": 9,
             "output_tokens": 1,
-            "total_tokens": 3,
+            "total_tokens": 10,
             "input_token_details": {"cache_creation": 3, "cache_read": 4},
         },
     )
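
The updated expectations follow directly from the fix. As a quick check of the arithmetic, assuming the fixture reports 2 uncached input tokens (as the old expectation suggests):

    # Values taken from the test fixture above.
    raw_input = 2        # uncached input tokens reported by Anthropic
    cache_creation = 3   # input_token_details["cache_creation"]
    cache_read = 4       # input_token_details["cache_read"]
    output_tokens = 1

    assert raw_input + cache_creation + cache_read == 9  # new input_tokens
    assert 9 + output_tokens == 10                       # new total_tokens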


@@ -153,28 +153,58 @@ class ChatModelIntegrationTests(ChatModelTests):
         if "audio_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_input()
-            assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)  # type: ignore[index]
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["input_token_details"] is not None
+            assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)
+            assert msg.usage_metadata["input_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["input_token_details"].values()
+            )
         if "audio_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_output()
-            assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)  # type: ignore[index]
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["output_token_details"] is not None
+            assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)
+            assert int(msg.usage_metadata["output_tokens"]) >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["output_token_details"].values()
+            )
         if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_reasoning_output()
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["output_token_details"] is not None
             assert isinstance(
-                msg.usage_metadata["output_token_details"]["reasoning"],  # type: ignore[index]
+                msg.usage_metadata["output_token_details"]["reasoning"],
                 int,
             )
+            assert msg.usage_metadata["output_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["output_token_details"].values()
+            )
         if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_read_input()
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["input_token_details"] is not None
             assert isinstance(
-                msg.usage_metadata["input_token_details"]["cache_read"],  # type: ignore[index]
+                msg.usage_metadata["input_token_details"]["cache_read"],
                 int,
             )
+            assert msg.usage_metadata["input_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["input_token_details"].values()
+            )
         if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_creation_input()
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["input_token_details"] is not None
             assert isinstance(
-                msg.usage_metadata["input_token_details"]["cache_creation"],  # type: ignore[index]
+                msg.usage_metadata["input_token_details"]["cache_creation"],
                 int,
             )
+            assert msg.usage_metadata["input_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["input_token_details"].values()
+            )

     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         if not self.returns_usage_metadata:
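
All of the new >= assertions enforce the same invariant: a headline token count must be at least the sum of its detail buckets, since the details are a breakdown of (part of) the total. A standalone sketch of that check, with values borrowed from the cached-usage test above:

    # Detail buckets break down part of the headline count, so their sum can
    # never exceed it; None-valued buckets count as zero.
    usage_metadata = {
        "input_tokens": 9,
        "input_token_details": {"cache_creation": 3, "cache_read": 4},
    }
    assert usage_metadata["input_tokens"] >= sum(
        v or 0 for v in usage_metadata["input_token_details"].values()
    )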