Mirror of https://github.com/hwchase17/langchain.git (synced 2025-04-28 11:55:21 +00:00)
anthropic[patch]: fix input_tokens when cached (#27125)
parent 64a16f2cf0
commit 0b8416bd2e
@@ -1253,7 +1253,12 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
         "cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
     }
 
-    input_tokens = getattr(anthropic_usage, "input_tokens", 0)
+    # Anthropic input_tokens exclude cached token counts.
+    input_tokens = (
+        getattr(anthropic_usage, "input_tokens", 0)
+        + (input_token_details["cache_read"] or 0)
+        + (input_token_details["cache_creation"] or 0)
+    )
     output_tokens = getattr(anthropic_usage, "output_tokens", 0)
     return UsageMetadata(
         input_tokens=input_tokens,
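For context, a minimal self-contained sketch of the adjusted accounting. The field names mirror the Anthropic usage payload read in the diff; the SimpleNamespace object is only an illustrative stand-in, and the token counts are borrowed from the unit test below.

from types import SimpleNamespace

# Illustrative stand-in for the Anthropic usage payload (not a real API object).
usage = SimpleNamespace(
    input_tokens=2,
    output_tokens=1,
    cache_read_input_tokens=4,
    cache_creation_input_tokens=3,
)

input_token_details = {
    "cache_read": getattr(usage, "cache_read_input_tokens", None),
    "cache_creation": getattr(usage, "cache_creation_input_tokens", None),
}

# Anthropic's input_tokens exclude cached tokens, so the cached counts are added back.
input_tokens = (
    getattr(usage, "input_tokens", 0)
    + (input_token_details["cache_read"] or 0)
    + (input_token_details["cache_creation"] or 0)
)
output_tokens = getattr(usage, "output_tokens", 0)

assert input_tokens == 9   # 2 (uncached) + 4 (cache_read) + 3 (cache_creation)
assert input_tokens + output_tokens == 10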
@@ -128,9 +128,9 @@ def test__format_output_cached() -> None:
     expected = AIMessage(  # type: ignore[misc]
         "bar",
         usage_metadata={
-            "input_tokens": 2,
+            "input_tokens": 9,
             "output_tokens": 1,
-            "total_tokens": 3,
+            "total_tokens": 10,
             "input_token_details": {"cache_creation": 3, "cache_read": 4},
         },
     )
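The updated expectations follow directly from the new accounting: adding the 4 cache-read tokens and 3 cache-creation tokens back onto the 2 uncached input tokens gives 9 input tokens, and 9 + 1 output token gives 10 total tokens.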
@@ -153,28 +153,58 @@ class ChatModelIntegrationTests(ChatModelTests):
         if "audio_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_input()
-            assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)  # type: ignore[index]
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["input_token_details"] is not None
+            assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)
+            assert msg.usage_metadata["input_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["input_token_details"].values()
+            )
         if "audio_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_output()
-            assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)  # type: ignore[index]
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["output_token_details"] is not None
+            assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)
+            assert int(msg.usage_metadata["output_tokens"]) >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["output_token_details"].values()
+            )
         if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_reasoning_output()
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["output_token_details"] is not None
             assert isinstance(
-                msg.usage_metadata["output_token_details"]["reasoning"],  # type: ignore[index]
+                msg.usage_metadata["output_token_details"]["reasoning"],
                 int,
             )
+            assert msg.usage_metadata["output_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["output_token_details"].values()
+            )
         if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_read_input()
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["input_token_details"] is not None
             assert isinstance(
-                msg.usage_metadata["input_token_details"]["cache_read"],  # type: ignore[index]
+                msg.usage_metadata["input_token_details"]["cache_read"],
                 int,
             )
+            assert msg.usage_metadata["input_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["input_token_details"].values()
+            )
         if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_creation_input()
+            assert msg.usage_metadata is not None
+            assert msg.usage_metadata["input_token_details"] is not None
             assert isinstance(
-                msg.usage_metadata["input_token_details"]["cache_creation"],  # type: ignore[index]
+                msg.usage_metadata["input_token_details"]["cache_creation"],
                 int,
             )
+            assert msg.usage_metadata["input_tokens"] >= sum(
+                (v or 0)  # type: ignore[misc]
+                for v in msg.usage_metadata["input_token_details"].values()
+            )
 
     def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         if not self.returns_usage_metadata:
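A minimal sketch of the invariant these strengthened assertions enforce, applied to a standalone UsageMetadata dict rather than a live model response (not part of the diff; the import path assumes langchain_core is installed, and the values reuse the unit test above):

from langchain_core.messages.ai import UsageMetadata

usage_metadata: UsageMetadata = {
    "input_tokens": 9,
    "output_tokens": 1,
    "total_tokens": 10,
    "input_token_details": {"cache_read": 4, "cache_creation": 3},
}

# Reported input_tokens must be at least the sum of every per-detail count,
# which only holds once cached tokens are folded into input_tokens.
assert usage_metadata["input_tokens"] >= sum(
    (v or 0) for v in usage_metadata["input_token_details"].values()
)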