anthropic: update streaming usage metadata (#31158)

Anthropic updated how they report token counts during streaming today.
See changes to `MessageDeltaUsage` in [this
commit](2da00f26c5 (diff-1a396eba0cd9cd8952dcdb58049d3b13f6b7768ead1411888d66e28211f7bfc5)).

It's clean and simple to grab these fields from the final
`message_delta` event. However, some of them are typed as Optional, and
language
[here](e42451ab3f/src/anthropic/lib/streaming/_messages.py (L462))
suggests they may not always be present. So here we take the required
field from the `message_delta` event as we were doing previously, and
ignore the rest.
This commit is contained in:
ccurme 2025-05-07 23:09:56 -04:00 committed by GitHub
parent 6c3901f9f9
commit e34f9fd6f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 3 deletions

View File

@ -1744,6 +1744,12 @@ def _make_message_chunk_from_anthropic_event(
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
usage_metadata = _create_usage_metadata(event.message.usage)
# We pick up a cumulative count of output_tokens at the end of the stream,
# so here we zero out to avoid double counting.
usage_metadata["total_tokens"] = (
usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
)
usage_metadata["output_tokens"] = 0
if hasattr(event.message, "model"):
response_metadata = {"model_name": event.message.model}
else:
@ -1817,7 +1823,11 @@ def _make_message_chunk_from_anthropic_event(
tool_call_chunks=[tool_call_chunk], # type: ignore
)
elif event.type == "message_delta" and stream_usage:
usage_metadata = _create_usage_metadata(event.usage)
usage_metadata = UsageMetadata(
input_tokens=0,
output_tokens=event.usage.output_tokens,
total_tokens=event.usage.output_tokens,
)
message_chunk = AIMessageChunk(
content="",
usage_metadata=usage_metadata,

View File

@ -46,7 +46,7 @@ def test_stream() -> None:
if token.usage_metadata is not None:
if token.usage_metadata.get("input_tokens"):
chunks_with_input_token_counts += 1
elif token.usage_metadata.get("output_tokens"):
if token.usage_metadata.get("output_tokens"):
chunks_with_output_token_counts += 1
chunks_with_model_name += int("model_name" in token.response_metadata)
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
@ -85,7 +85,7 @@ async def test_astream() -> None:
if token.usage_metadata is not None:
if token.usage_metadata.get("input_tokens"):
chunks_with_input_token_counts += 1
elif token.usage_metadata.get("output_tokens"):
if token.usage_metadata.get("output_tokens"):
chunks_with_output_token_counts += 1
if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
raise AssertionError(