From d3d23e2372b6cd0fc36d32625eaea1e8fcd64b59 Mon Sep 17 00:00:00 2001 From: Mason Daugherty Date: Fri, 15 Aug 2025 17:49:46 -0400 Subject: [PATCH] fix(anthropic): streaming token counting to defer input tokens until completion (#32518) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Supersedes #32461 Fixed incorrect input token reporting during streaming when tools are used. Previously, input tokens were counted at `message_start` before tool execution, leading to inaccurate counts. Now input tokens are properly deferred until `message_delta` (completion), aligning with Anthropic's billing model and SDK expectations. **Before Fix:** - Streaming with tools: Input tokens = 0 ❌ - Non-streaming with tools: Input tokens = 472 ✅ **After Fix:** - Streaming with tools: Input tokens = 472 ✅ - Non-streaming with tools: Input tokens = 472 ✅ Aligns with Anthropic's SDK expectations. The SDK handles input token updates in `message_delta` events: ```python # https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py if event.usage.input_tokens is not None: current_snapshot.usage.input_tokens = event.usage.input_tokens ``` --- .../langchain_anthropic/chat_models.py | 102 ++++++-- .../tests/unit_tests/test_chat_models.py | 225 +++++++++++++++++- 2 files changed, 310 insertions(+), 17 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 454c52d4d6c..410fee9a7dd 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -70,6 +70,20 @@ class AnthropicTool(TypedDict): cache_control: NotRequired[dict[str, str]] +class _CombinedUsage(BaseModel): + """Combined usage model for deferred token counting in streaming. + + This mimics the Anthropic Usage structure while combining stored input usage + with final output usage for accurate token reporting during streaming. 
+ """ + + input_tokens: int = 0 + output_tokens: int = 0 + cache_creation_input_tokens: Optional[int] = None + cache_read_input_tokens: Optional[int] = None + cache_creation: Optional[dict[str, Any]] = None + + def _is_builtin_tool(tool: Any) -> bool: if not isinstance(tool, dict): return False @@ -1493,12 +1507,18 @@ class ChatAnthropic(BaseChatModel): and not _thinking_in_params(payload) ) block_start_event = None + stored_input_usage = None for event in stream: - msg, block_start_event = _make_message_chunk_from_anthropic_event( + ( + msg, + block_start_event, + stored_input_usage, + ) = _make_message_chunk_from_anthropic_event( event, stream_usage=stream_usage, coerce_content_to_string=coerce_content_to_string, block_start_event=block_start_event, + stored_input_usage=stored_input_usage, ) if msg is not None: chunk = ChatGenerationChunk(message=msg) @@ -1529,12 +1549,18 @@ class ChatAnthropic(BaseChatModel): and not _thinking_in_params(payload) ) block_start_event = None + stored_input_usage = None async for event in stream: - msg, block_start_event = _make_message_chunk_from_anthropic_event( + ( + msg, + block_start_event, + stored_input_usage, + ) = _make_message_chunk_from_anthropic_event( event, stream_usage=stream_usage, coerce_content_to_string=coerce_content_to_string, block_start_event=block_start_event, + stored_input_usage=stored_input_usage, ) if msg is not None: chunk = ChatGenerationChunk(message=msg) @@ -2167,22 +2193,40 @@ def _make_message_chunk_from_anthropic_event( stream_usage: bool = True, coerce_content_to_string: bool, block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None, -) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]: - """Convert Anthropic event to AIMessageChunk. + stored_input_usage: Optional[BaseModel] = None, +) -> tuple[ + Optional[AIMessageChunk], + Optional[anthropic.types.RawMessageStreamEvent], + Optional[BaseModel], +]: + """Convert Anthropic event to ``AIMessageChunk``. Note that not all events will result in a message chunk. In these cases we return ``None``. + + Args: + event: The Anthropic streaming event to convert. + stream_usage: Whether to include usage metadata in the chunk. + coerce_content_to_string: Whether to coerce content blocks to strings. + block_start_event: Previous content block start event for context. + stored_input_usage: Usage metadata from ``message_start`` event to be used + in ``message_delta`` event for accurate input token counts. + + Returns: + Tuple of ``(message_chunk, block_start_event, stored_usage)`` + """ message_chunk: Optional[AIMessageChunk] = None + updated_stored_usage = stored_input_usage # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501 if event.type == "message_start" and stream_usage: - usage_metadata = _create_usage_metadata(event.message.usage) - # We pick up a cumulative count of output_tokens at the end of the stream, - # so here we zero out to avoid double counting. 
- usage_metadata["total_tokens"] = ( - usage_metadata["total_tokens"] - usage_metadata["output_tokens"] + # Store input usage for later use in message_delta but don't emit tokens yet + updated_stored_usage = event.message.usage + usage_metadata = UsageMetadata( + input_tokens=0, + output_tokens=0, + total_tokens=0, ) - usage_metadata["output_tokens"] = 0 if hasattr(event.message, "model"): response_metadata = {"model_name": event.message.model} else: @@ -2270,11 +2314,37 @@ def _make_message_chunk_from_anthropic_event( tool_call_chunks=tool_call_chunks, ) elif event.type == "message_delta" and stream_usage: - usage_metadata = UsageMetadata( - input_tokens=0, - output_tokens=event.usage.output_tokens, - total_tokens=event.usage.output_tokens, - ) + # Create usage metadata combining stored input usage with final output usage + # + # Per Anthropic docs: "The token counts shown in the usage field of the + # message_delta event are cumulative." Thus, when MCP tools are called + # mid-stream, `input_tokens` may be updated with a higher cumulative count. + # We prioritize `event.usage.input_tokens` when available to handle this case. + if stored_input_usage is not None: + # Create a combined usage object that mimics the Anthropic Usage structure + combined_usage = _CombinedUsage( + input_tokens=event.usage.input_tokens + or getattr(stored_input_usage, "input_tokens", 0), + output_tokens=event.usage.output_tokens, + cache_creation_input_tokens=getattr( + stored_input_usage, "cache_creation_input_tokens", None + ), + cache_read_input_tokens=getattr( + stored_input_usage, "cache_read_input_tokens", None + ), + cache_creation=getattr(stored_input_usage, "cache_creation", None) + if hasattr(stored_input_usage, "cache_creation") + else None, + ) + usage_metadata = _create_usage_metadata(combined_usage) + else: + # Fallback to just output tokens if no stored usage + usage_metadata = UsageMetadata( + input_tokens=event.usage.input_tokens or 0, + output_tokens=event.usage.output_tokens, + total_tokens=(event.usage.input_tokens or 0) + + event.usage.output_tokens, + ) message_chunk = AIMessageChunk( content="", usage_metadata=usage_metadata, @@ -2286,7 +2356,7 @@ def _make_message_chunk_from_anthropic_event( else: pass - return message_chunk, block_start_event + return message_chunk, block_start_event, updated_stored_usage @deprecated(since="0.1.0", removal="1.0.0", alternative="ChatAnthropic") diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 2a1b77e0b81..fee1125ce9e 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -3,12 +3,13 @@ from __future__ import annotations import os +from types import SimpleNamespace from typing import Any, Callable, Literal, Optional, cast from unittest.mock import MagicMock, patch import anthropic import pytest -from anthropic.types import Message, TextBlock, Usage +from anthropic.types import Message, MessageDeltaUsage, TextBlock, Usage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableBinding from langchain_core.tools import BaseTool @@ -22,6 +23,7 @@ from langchain_anthropic.chat_models import ( _create_usage_metadata, _format_image, _format_messages, + _make_message_chunk_from_anthropic_event, _merge_messages, convert_to_anthropic_tool, ) @@ -1172,3 +1174,224 @@ def test_cache_control_kwarg() -> None: ], }, ] 
+ + +def test_streaming_token_counting_deferred() -> None: + """Test streaming defers input token counting until message completion. + + Validates that the streaming implementation correctly: + 1. Stores input tokens from `message_start` without emitting them immediately + 2. Combines stored input tokens with output tokens at `message_delta` completion + 3. Only emits complete token usage metadata when the message is finished + + This prevents the bug where tools would cause inaccurate token counts due to + premature emission of input tokens before tool execution completed. + """ + # Mock `message_start` event with usage + message_start_event = SimpleNamespace( + type="message_start", + message=SimpleNamespace( + usage=Usage( + input_tokens=100, + output_tokens=1, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + ), + model="claude-opus-4-1-20250805", + ), + ) + + # Mock `message_delta` event with final output tokens + message_delta_event = SimpleNamespace( + type="message_delta", + usage=MessageDeltaUsage( + output_tokens=50, + input_tokens=None, # This is None in real delta events + cache_creation_input_tokens=None, + cache_read_input_tokens=None, + ), + delta=SimpleNamespace( + stop_reason="end_turn", + stop_sequence=None, + ), + ) + + # Test `message_start` event - should store input tokens but not emit them + msg_chunk, _, stored_usage = _make_message_chunk_from_anthropic_event( + message_start_event, # type: ignore[arg-type] + stream_usage=True, + coerce_content_to_string=True, + stored_input_usage=None, + ) + + assert msg_chunk is not None + assert msg_chunk.usage_metadata is not None + + # Input tokens should be 0 at message_start (deferred) + assert msg_chunk.usage_metadata["input_tokens"] == 0 + assert msg_chunk.usage_metadata["output_tokens"] == 0 + assert msg_chunk.usage_metadata["total_tokens"] == 0 + + # Usage should be stored + assert stored_usage is not None + assert getattr(stored_usage, "input_tokens", 0) == 100 + + # Test `message_delta` - combine stored input with delta output tokens + msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( + message_delta_event, # type: ignore[arg-type] + stream_usage=True, + coerce_content_to_string=True, + stored_input_usage=stored_usage, + ) + + assert msg_chunk is not None + assert msg_chunk.usage_metadata is not None + + # Should now have the complete usage metadata + assert msg_chunk.usage_metadata["input_tokens"] == 100 # From stored usage + assert msg_chunk.usage_metadata["output_tokens"] == 50 # From delta event + assert msg_chunk.usage_metadata["total_tokens"] == 150 + + # Verify response metadata is properly set + assert "stop_reason" in msg_chunk.response_metadata + assert msg_chunk.response_metadata["stop_reason"] == "end_turn" + + +def test_streaming_token_counting_fallback() -> None: + """Test streaming token counting gracefully handles missing stored usage. + + Validates that when no stored input usage is available (edge case scenario), + the streaming implementation safely falls back to reporting only output tokens + rather than failing or returning invalid token counts. 
+ """ + # Mock message_delta event without stored input usage + message_delta_event = SimpleNamespace( + type="message_delta", + usage=MessageDeltaUsage( + output_tokens=25, + input_tokens=None, + cache_creation_input_tokens=None, + cache_read_input_tokens=None, + ), + delta=SimpleNamespace( + stop_reason="end_turn", + stop_sequence=None, + ), + ) + + # Test message_delta without stored usage - should fallback gracefully + msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( + message_delta_event, # type: ignore[arg-type] + stream_usage=True, + coerce_content_to_string=True, + stored_input_usage=None, # No stored usage + ) + + assert msg_chunk is not None + assert msg_chunk.usage_metadata is not None + + # Should fallback to 0 input tokens and only report output tokens + assert msg_chunk.usage_metadata["input_tokens"] == 0 + assert msg_chunk.usage_metadata["output_tokens"] == 25 + assert msg_chunk.usage_metadata["total_tokens"] == 25 + + +def test_streaming_token_counting_cumulative_input_tokens() -> None: + """Test streaming handles cumulative input tokens from `message_delta` events. + + Validates that when Anthropic sends updated cumulative input tokens in + `message_delta` events (e.g., due to MCP tool calling), the implementation + prioritizes these updated counts over stored input usage. + + """ + # Mock `message_start` event with initial usage + message_start_event = SimpleNamespace( + type="message_start", + message=SimpleNamespace( + usage=Usage( + input_tokens=100, # Initial input tokens + output_tokens=1, + cache_creation_input_tokens=0, + cache_read_input_tokens=0, + ), + model="claude-opus-4-1-20250805", + ), + ) + + # Mock `message_delta` event with updated cumulative input tokens + # This happens when MCP tools are called mid-stream + message_delta_event = SimpleNamespace( + type="message_delta", + usage=MessageDeltaUsage( + output_tokens=50, + input_tokens=120, # Cumulative count increased due to tool calling + cache_creation_input_tokens=None, + cache_read_input_tokens=None, + ), + delta=SimpleNamespace( + stop_reason="end_turn", + stop_sequence=None, + ), + ) + + # Store input usage from `message_start` + _, _, stored_usage = _make_message_chunk_from_anthropic_event( + message_start_event, # type: ignore[arg-type] + stream_usage=True, + coerce_content_to_string=True, + stored_input_usage=None, + ) + + # Test `message_delta` with cumulative input tokens + msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( + message_delta_event, # type: ignore[arg-type] + stream_usage=True, + coerce_content_to_string=True, + stored_input_usage=stored_usage, + ) + + assert msg_chunk is not None + assert msg_chunk.usage_metadata is not None + + # Should use the cumulative input tokens from event (120) not stored (100) + assert msg_chunk.usage_metadata["input_tokens"] == 120 + assert msg_chunk.usage_metadata["output_tokens"] == 50 + assert msg_chunk.usage_metadata["total_tokens"] == 170 + + +def test_streaming_token_counting_cumulative_fallback() -> None: + """Test fallback handles cumulative input tokens from message_delta events. + + When no stored usage is available, validates that cumulative input tokens + from the message_delta event are still properly used instead of defaulting to 0. 
+ """ + # Mock `message_delta` event with cumulative input tokens but no stored usage + message_delta_event = SimpleNamespace( + type="message_delta", + usage=MessageDeltaUsage( + output_tokens=30, + input_tokens=85, # Cumulative input tokens in the event + cache_creation_input_tokens=None, + cache_read_input_tokens=None, + ), + delta=SimpleNamespace( + stop_reason="end_turn", + stop_sequence=None, + ), + ) + + # Test `message_delta` without stored usage - should use event's input tokens + msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( + message_delta_event, # type: ignore[arg-type] + stream_usage=True, + coerce_content_to_string=True, + stored_input_usage=None, # No stored usage + ) + + assert msg_chunk is not None + assert msg_chunk.usage_metadata is not None + + # Should use cumulative input tokens from event, not fallback to 0 + assert msg_chunk.usage_metadata["input_tokens"] == 85 # From event + assert msg_chunk.usage_metadata["output_tokens"] == 30 + assert msg_chunk.usage_metadata["total_tokens"] == 115