fix(anthropic): defer streaming input token counting until message completion (#32518)

Supersedes #32461

Fixed incorrect input token reporting during streaming when tools are
used. Previously, input tokens were counted at `message_start` before
tool execution, leading to inaccurate counts. Now input tokens are
properly deferred until `message_delta` (completion), aligning with
Anthropic's billing model and SDK expectations.

**Before Fix:**
- Streaming with tools: Input tokens = 0 
- Non-streaming with tools: Input tokens = 472 

**After Fix:**
- Streaming with tools: Input tokens = 472 
- Non-streaming with tools: Input tokens = 472 

Aligns with Anthropic's SDK expectations. The SDK handles input token
updates in `message_delta` events:

```python
# https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py
if event.usage.input_tokens is not None:
      current_snapshot.usage.input_tokens = event.usage.input_tokens
```
This commit is contained in:
Mason Daugherty 2025-08-15 17:49:46 -04:00 committed by GitHub
parent 2f32c444b8
commit d3d23e2372
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 310 additions and 17 deletions

View File

@ -70,6 +70,20 @@ class AnthropicTool(TypedDict):
cache_control: NotRequired[dict[str, str]]
class _CombinedUsage(BaseModel):
    """Combined usage model for deferred token counting in streaming.

    This mimics the Anthropic Usage structure while combining stored input usage
    with final output usage for accurate token reporting during streaming.
    """

    # Prompt token count carried over from the stored ``message_start`` usage
    # (or the cumulative count from ``message_delta`` when provided).
    input_tokens: int = 0
    # Completion token count taken from the final ``message_delta`` usage.
    output_tokens: int = 0
    # Cache-related input token counts, copied from the stored input usage when
    # the underlying SDK object exposes them; otherwise left as None.
    cache_creation_input_tokens: Optional[int] = None
    cache_read_input_tokens: Optional[int] = None
    # Raw cache-creation detail mapping from the stored usage, if present.
    cache_creation: Optional[dict[str, Any]] = None
def _is_builtin_tool(tool: Any) -> bool:
if not isinstance(tool, dict):
return False
@ -1493,12 +1507,18 @@ class ChatAnthropic(BaseChatModel):
and not _thinking_in_params(payload)
)
block_start_event = None
stored_input_usage = None
for event in stream:
msg, block_start_event = _make_message_chunk_from_anthropic_event(
(
msg,
block_start_event,
stored_input_usage,
) = _make_message_chunk_from_anthropic_event(
event,
stream_usage=stream_usage,
coerce_content_to_string=coerce_content_to_string,
block_start_event=block_start_event,
stored_input_usage=stored_input_usage,
)
if msg is not None:
chunk = ChatGenerationChunk(message=msg)
@ -1529,12 +1549,18 @@ class ChatAnthropic(BaseChatModel):
and not _thinking_in_params(payload)
)
block_start_event = None
stored_input_usage = None
async for event in stream:
msg, block_start_event = _make_message_chunk_from_anthropic_event(
(
msg,
block_start_event,
stored_input_usage,
) = _make_message_chunk_from_anthropic_event(
event,
stream_usage=stream_usage,
coerce_content_to_string=coerce_content_to_string,
block_start_event=block_start_event,
stored_input_usage=stored_input_usage,
)
if msg is not None:
chunk = ChatGenerationChunk(message=msg)
@ -2167,22 +2193,40 @@ def _make_message_chunk_from_anthropic_event(
stream_usage: bool = True,
coerce_content_to_string: bool,
block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None,
) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]:
"""Convert Anthropic event to AIMessageChunk.
stored_input_usage: Optional[BaseModel] = None,
) -> tuple[
Optional[AIMessageChunk],
Optional[anthropic.types.RawMessageStreamEvent],
Optional[BaseModel],
]:
"""Convert Anthropic event to ``AIMessageChunk``.
Note that not all events will result in a message chunk. In these cases
we return ``None``.
Args:
event: The Anthropic streaming event to convert.
stream_usage: Whether to include usage metadata in the chunk.
coerce_content_to_string: Whether to coerce content blocks to strings.
block_start_event: Previous content block start event for context.
stored_input_usage: Usage metadata from ``message_start`` event to be used
in ``message_delta`` event for accurate input token counts.
Returns:
Tuple of ``(message_chunk, block_start_event, stored_usage)``
"""
message_chunk: Optional[AIMessageChunk] = None
updated_stored_usage = stored_input_usage
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
usage_metadata = _create_usage_metadata(event.message.usage)
# We pick up a cumulative count of output_tokens at the end of the stream,
# so here we zero out to avoid double counting.
usage_metadata["total_tokens"] = (
usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
# Store input usage for later use in message_delta but don't emit tokens yet
updated_stored_usage = event.message.usage
usage_metadata = UsageMetadata(
input_tokens=0,
output_tokens=0,
total_tokens=0,
)
usage_metadata["output_tokens"] = 0
if hasattr(event.message, "model"):
response_metadata = {"model_name": event.message.model}
else:
@ -2270,10 +2314,36 @@ def _make_message_chunk_from_anthropic_event(
tool_call_chunks=tool_call_chunks,
)
elif event.type == "message_delta" and stream_usage:
usage_metadata = UsageMetadata(
input_tokens=0,
# Create usage metadata combining stored input usage with final output usage
#
# Per Anthropic docs: "The token counts shown in the usage field of the
# message_delta event are cumulative." Thus, when MCP tools are called
# mid-stream, `input_tokens` may be updated with a higher cumulative count.
# We prioritize `event.usage.input_tokens` when available to handle this case.
if stored_input_usage is not None:
# Create a combined usage object that mimics the Anthropic Usage structure
combined_usage = _CombinedUsage(
input_tokens=event.usage.input_tokens
or getattr(stored_input_usage, "input_tokens", 0),
output_tokens=event.usage.output_tokens,
total_tokens=event.usage.output_tokens,
cache_creation_input_tokens=getattr(
stored_input_usage, "cache_creation_input_tokens", None
),
cache_read_input_tokens=getattr(
stored_input_usage, "cache_read_input_tokens", None
),
cache_creation=getattr(stored_input_usage, "cache_creation", None)
if hasattr(stored_input_usage, "cache_creation")
else None,
)
usage_metadata = _create_usage_metadata(combined_usage)
else:
# Fallback to just output tokens if no stored usage
usage_metadata = UsageMetadata(
input_tokens=event.usage.input_tokens or 0,
output_tokens=event.usage.output_tokens,
total_tokens=(event.usage.input_tokens or 0)
+ event.usage.output_tokens,
)
message_chunk = AIMessageChunk(
content="",
@ -2286,7 +2356,7 @@ def _make_message_chunk_from_anthropic_event(
else:
pass
return message_chunk, block_start_event
return message_chunk, block_start_event, updated_stored_usage
@deprecated(since="0.1.0", removal="1.0.0", alternative="ChatAnthropic")

View File

@ -3,12 +3,13 @@
from __future__ import annotations
import os
from types import SimpleNamespace
from typing import Any, Callable, Literal, Optional, cast
from unittest.mock import MagicMock, patch
import anthropic
import pytest
from anthropic.types import Message, TextBlock, Usage
from anthropic.types import Message, MessageDeltaUsage, TextBlock, Usage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import BaseTool
@ -22,6 +23,7 @@ from langchain_anthropic.chat_models import (
_create_usage_metadata,
_format_image,
_format_messages,
_make_message_chunk_from_anthropic_event,
_merge_messages,
convert_to_anthropic_tool,
)
@ -1172,3 +1174,224 @@ def test_cache_control_kwarg() -> None:
],
},
]
def test_streaming_token_counting_deferred() -> None:
    """Input token reporting must be deferred until the message completes.

    Checks that the streaming event handler:
    1. Stashes the usage from ``message_start`` instead of emitting input tokens
    2. Merges the stashed input count with the output count at ``message_delta``
    3. Only emits complete usage metadata once the message has finished

    This guards against the bug where tool execution caused input tokens to be
    reported prematurely (and therefore inaccurately) mid-stream.
    """
    # Fake `message_start` event carrying the initial usage snapshot
    start_event = SimpleNamespace(
        type="message_start",
        message=SimpleNamespace(
            usage=Usage(
                input_tokens=100,
                output_tokens=1,
                cache_creation_input_tokens=0,
                cache_read_input_tokens=0,
            ),
            model="claude-opus-4-1-20250805",
        ),
    )
    # Fake `message_delta` event carrying the final output token count
    delta_event = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=50,
            input_tokens=None,  # Real delta events carry None here
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    # Step 1: `message_start` stores the input usage but emits zeroed counts
    chunk, _, saved_usage = _make_message_chunk_from_anthropic_event(
        start_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,
    )
    assert chunk is not None
    assert chunk.usage_metadata is not None
    # Deferred: nothing reported yet at message_start
    assert chunk.usage_metadata["input_tokens"] == 0
    assert chunk.usage_metadata["output_tokens"] == 0
    assert chunk.usage_metadata["total_tokens"] == 0
    # ...but the usage snapshot was captured for later
    assert saved_usage is not None
    assert getattr(saved_usage, "input_tokens", 0) == 100

    # Step 2: `message_delta` merges stored input with the delta's output count
    chunk, _, _ = _make_message_chunk_from_anthropic_event(
        delta_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=saved_usage,
    )
    assert chunk is not None
    assert chunk.usage_metadata is not None
    # Complete usage metadata is emitted exactly once, at completion
    assert chunk.usage_metadata["input_tokens"] == 100  # from the stored usage
    assert chunk.usage_metadata["output_tokens"] == 50  # from the delta event
    assert chunk.usage_metadata["total_tokens"] == 150
    # Response metadata from the delta is propagated as well
    assert "stop_reason" in chunk.response_metadata
    assert chunk.response_metadata["stop_reason"] == "end_turn"
def test_streaming_token_counting_fallback() -> None:
    """Missing stored usage must degrade gracefully, not fail.

    When a ``message_delta`` arrives without any previously stored input usage
    (an edge case), the handler should report output tokens only — never raise
    or produce bogus input counts.
    """
    # A delta event arriving with no prior message_start usage on record
    orphan_delta = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=25,
            input_tokens=None,
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    chunk, _, _ = _make_message_chunk_from_anthropic_event(
        orphan_delta,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,  # Nothing stored — exercise the fallback path
    )
    assert chunk is not None
    assert chunk.usage_metadata is not None
    # Fallback: zero input tokens, output tokens reported as-is
    assert chunk.usage_metadata["input_tokens"] == 0
    assert chunk.usage_metadata["output_tokens"] == 25
    assert chunk.usage_metadata["total_tokens"] == 25
def test_streaming_token_counting_cumulative_input_tokens() -> None:
    """Cumulative input tokens in ``message_delta`` take priority over stored usage.

    When Anthropic reports an updated cumulative input count in the
    ``message_delta`` event (e.g. after mid-stream MCP tool calls), that count
    must win over the figure captured at ``message_start``.
    """
    # `message_start` with the initial input count
    start_event = SimpleNamespace(
        type="message_start",
        message=SimpleNamespace(
            usage=Usage(
                input_tokens=100,  # Initial input tokens
                output_tokens=1,
                cache_creation_input_tokens=0,
                cache_read_input_tokens=0,
            ),
            model="claude-opus-4-1-20250805",
        ),
    )
    # `message_delta` carrying a *higher* cumulative input count, as happens
    # when MCP tools run mid-stream
    delta_event = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=50,
            input_tokens=120,  # Cumulative count increased by tool calling
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    # Capture the usage snapshot from message_start
    _, _, saved_usage = _make_message_chunk_from_anthropic_event(
        start_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,
    )

    # Feed the delta that carries the cumulative input count
    chunk, _, _ = _make_message_chunk_from_anthropic_event(
        delta_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=saved_usage,
    )
    assert chunk is not None
    assert chunk.usage_metadata is not None
    # The event's cumulative count (120) wins over the stored one (100)
    assert chunk.usage_metadata["input_tokens"] == 120
    assert chunk.usage_metadata["output_tokens"] == 50
    assert chunk.usage_metadata["total_tokens"] == 170
def test_streaming_token_counting_cumulative_fallback() -> None:
    """The fallback path must still honor cumulative input tokens from the event.

    Even with no stored usage on hand, a cumulative input count present in the
    ``message_delta`` event itself must be used rather than defaulting to 0.
    """
    # Delta event with its own cumulative input count but no stored usage
    orphan_delta = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=30,
            input_tokens=85,  # Cumulative input tokens supplied by the event
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    chunk, _, _ = _make_message_chunk_from_anthropic_event(
        orphan_delta,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,  # Nothing stored — fallback path
    )
    assert chunk is not None
    assert chunk.usage_metadata is not None
    # The event's own count is used instead of falling back to 0
    assert chunk.usage_metadata["input_tokens"] == 85  # from the event
    assert chunk.usage_metadata["output_tokens"] == 30
    assert chunk.usage_metadata["total_tokens"] == 115