diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 7629b0b5c63..2ec8bfe916d 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -70,20 +70,6 @@ class AnthropicTool(TypedDict): cache_control: NotRequired[dict[str, str]] -class _CombinedUsage(BaseModel): - """Combined usage model for deferred token counting in streaming. - - This mimics the Anthropic Usage structure while combining stored input usage - with final output usage for accurate token reporting during streaming. - """ - - input_tokens: int = 0 - output_tokens: int = 0 - cache_creation_input_tokens: Optional[int] = None - cache_read_input_tokens: Optional[int] = None - cache_creation: Optional[dict[str, Any]] = None - - def _is_builtin_tool(tool: Any) -> bool: if not isinstance(tool, dict): return False @@ -1522,18 +1508,12 @@ class ChatAnthropic(BaseChatModel): and not _thinking_in_params(payload) ) block_start_event = None - stored_input_usage = None for event in stream: - ( - msg, - block_start_event, - stored_input_usage, - ) = _make_message_chunk_from_anthropic_event( + msg, block_start_event = _make_message_chunk_from_anthropic_event( event, stream_usage=stream_usage, coerce_content_to_string=coerce_content_to_string, block_start_event=block_start_event, - stored_input_usage=stored_input_usage, ) if msg is not None: chunk = ChatGenerationChunk(message=msg) @@ -1564,18 +1544,12 @@ class ChatAnthropic(BaseChatModel): and not _thinking_in_params(payload) ) block_start_event = None - stored_input_usage = None async for event in stream: - ( - msg, - block_start_event, - stored_input_usage, - ) = _make_message_chunk_from_anthropic_event( + msg, block_start_event = _make_message_chunk_from_anthropic_event( event, stream_usage=stream_usage, coerce_content_to_string=coerce_content_to_string, block_start_event=block_start_event, - stored_input_usage=stored_input_usage, ) if msg is not None: chunk = ChatGenerationChunk(message=msg) @@ -2208,40 +2182,22 @@ def _make_message_chunk_from_anthropic_event( stream_usage: bool = True, coerce_content_to_string: bool, block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None, - stored_input_usage: Optional[BaseModel] = None, -) -> tuple[ - Optional[AIMessageChunk], - Optional[anthropic.types.RawMessageStreamEvent], - Optional[BaseModel], -]: - """Convert Anthropic event to ``AIMessageChunk``. +) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]: + """Convert Anthropic event to AIMessageChunk. Note that not all events will result in a message chunk. In these cases we return ``None``. - - Args: - event: The Anthropic streaming event to convert. - stream_usage: Whether to include usage metadata in the chunk. - coerce_content_to_string: Whether to coerce content blocks to strings. - block_start_event: Previous content block start event for context. - stored_input_usage: Usage metadata from ``message_start`` event to be used - in ``message_delta`` event for accurate input token counts. - - Returns: - Tuple of ``(message_chunk, block_start_event, stored_usage)`` - """ message_chunk: Optional[AIMessageChunk] = None - updated_stored_usage = stored_input_usage # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501 if event.type == "message_start" and stream_usage: - # Store input usage for later use in message_delta but don't emit tokens yet - updated_stored_usage = event.message.usage - usage_metadata = UsageMetadata( - input_tokens=0, - output_tokens=0, - total_tokens=0, + usage_metadata = _create_usage_metadata(event.message.usage) + # We pick up a cumulative count of output_tokens at the end of the stream, + # so here we zero out to avoid double counting. + usage_metadata["total_tokens"] = ( + usage_metadata["total_tokens"] - usage_metadata["output_tokens"] ) + usage_metadata["output_tokens"] = 0 if hasattr(event.message, "model"): response_metadata = {"model_name": event.message.model} else: @@ -2329,37 +2285,11 @@ def _make_message_chunk_from_anthropic_event( tool_call_chunks=tool_call_chunks, ) elif event.type == "message_delta" and stream_usage: - # Create usage metadata combining stored input usage with final output usage - # - # Per Anthropic docs: "The token counts shown in the usage field of the - # message_delta event are cumulative." Thus, when MCP tools are called - # mid-stream, `input_tokens` may be updated with a higher cumulative count. - # We prioritize `event.usage.input_tokens` when available to handle this case. - if stored_input_usage is not None: - # Create a combined usage object that mimics the Anthropic Usage structure - combined_usage = _CombinedUsage( - input_tokens=event.usage.input_tokens - or getattr(stored_input_usage, "input_tokens", 0), - output_tokens=event.usage.output_tokens, - cache_creation_input_tokens=getattr( - stored_input_usage, "cache_creation_input_tokens", None - ), - cache_read_input_tokens=getattr( - stored_input_usage, "cache_read_input_tokens", None - ), - cache_creation=getattr(stored_input_usage, "cache_creation", None) - if hasattr(stored_input_usage, "cache_creation") - else None, - ) - usage_metadata = _create_usage_metadata(combined_usage) - else: - # Fallback to just output tokens if no stored usage - usage_metadata = UsageMetadata( - input_tokens=event.usage.input_tokens or 0, - output_tokens=event.usage.output_tokens, - total_tokens=(event.usage.input_tokens or 0) - + event.usage.output_tokens, - ) + usage_metadata = UsageMetadata( + input_tokens=0, + output_tokens=event.usage.output_tokens, + total_tokens=event.usage.output_tokens, + ) message_chunk = AIMessageChunk( content="", usage_metadata=usage_metadata, @@ -2371,7 +2301,7 @@ def _make_message_chunk_from_anthropic_event( else: pass - return message_chunk, block_start_event, updated_stored_usage + return message_chunk, block_start_event @deprecated(since="0.1.0", removal="1.0.0", alternative="ChatAnthropic") diff --git a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py index 0084d2740ac..8a30427d1c1 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/unit_tests/test_chat_models.py @@ -3,13 +3,12 @@ from __future__ import annotations import os -from types import SimpleNamespace from typing import Any, Callable, Literal, Optional, cast from unittest.mock import MagicMock, patch import anthropic import pytest -from anthropic.types import Message, MessageDeltaUsage, TextBlock, Usage +from anthropic.types import Message, TextBlock, Usage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.runnables import RunnableBinding from langchain_core.tools import BaseTool @@ -23,7 +22,6 @@ from langchain_anthropic.chat_models import ( _create_usage_metadata, _format_image, _format_messages, - _make_message_chunk_from_anthropic_event, _merge_messages, convert_to_anthropic_tool, ) @@ -1215,224 +1213,3 @@ def test_cache_control_kwarg() -> None: ], }, ] - - -def test_streaming_token_counting_deferred() -> None: - """Test streaming defers input token counting until message completion. - - Validates that the streaming implementation correctly: - 1. Stores input tokens from `message_start` without emitting them immediately - 2. Combines stored input tokens with output tokens at `message_delta` completion - 3. Only emits complete token usage metadata when the message is finished - - This prevents the bug where tools would cause inaccurate token counts due to - premature emission of input tokens before tool execution completed. - """ - # Mock `message_start` event with usage - message_start_event = SimpleNamespace( - type="message_start", - message=SimpleNamespace( - usage=Usage( - input_tokens=100, - output_tokens=1, - cache_creation_input_tokens=0, - cache_read_input_tokens=0, - ), - model="claude-opus-4-1-20250805", - ), - ) - - # Mock `message_delta` event with final output tokens - message_delta_event = SimpleNamespace( - type="message_delta", - usage=MessageDeltaUsage( - output_tokens=50, - input_tokens=None, # This is None in real delta events - cache_creation_input_tokens=None, - cache_read_input_tokens=None, - ), - delta=SimpleNamespace( - stop_reason="end_turn", - stop_sequence=None, - ), - ) - - # Test `message_start` event - should store input tokens but not emit them - msg_chunk, _, stored_usage = _make_message_chunk_from_anthropic_event( - message_start_event, # type: ignore[arg-type] - stream_usage=True, - coerce_content_to_string=True, - stored_input_usage=None, - ) - - assert msg_chunk is not None - assert msg_chunk.usage_metadata is not None - - # Input tokens should be 0 at message_start (deferred) - assert msg_chunk.usage_metadata["input_tokens"] == 0 - assert msg_chunk.usage_metadata["output_tokens"] == 0 - assert msg_chunk.usage_metadata["total_tokens"] == 0 - - # Usage should be stored - assert stored_usage is not None - assert getattr(stored_usage, "input_tokens", 0) == 100 - - # Test `message_delta` - combine stored input with delta output tokens - msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( - message_delta_event, # type: ignore[arg-type] - stream_usage=True, - coerce_content_to_string=True, - stored_input_usage=stored_usage, - ) - - assert msg_chunk is not None - assert msg_chunk.usage_metadata is not None - - # Should now have the complete usage metadata - assert msg_chunk.usage_metadata["input_tokens"] == 100 # From stored usage - assert msg_chunk.usage_metadata["output_tokens"] == 50 # From delta event - assert msg_chunk.usage_metadata["total_tokens"] == 150 - - # Verify response metadata is properly set - assert "stop_reason" in msg_chunk.response_metadata - assert msg_chunk.response_metadata["stop_reason"] == "end_turn" - - -def test_streaming_token_counting_fallback() -> None: - """Test streaming token counting gracefully handles missing stored usage. - - Validates that when no stored input usage is available (edge case scenario), - the streaming implementation safely falls back to reporting only output tokens - rather than failing or returning invalid token counts. - """ - # Mock message_delta event without stored input usage - message_delta_event = SimpleNamespace( - type="message_delta", - usage=MessageDeltaUsage( - output_tokens=25, - input_tokens=None, - cache_creation_input_tokens=None, - cache_read_input_tokens=None, - ), - delta=SimpleNamespace( - stop_reason="end_turn", - stop_sequence=None, - ), - ) - - # Test message_delta without stored usage - should fallback gracefully - msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( - message_delta_event, # type: ignore[arg-type] - stream_usage=True, - coerce_content_to_string=True, - stored_input_usage=None, # No stored usage - ) - - assert msg_chunk is not None - assert msg_chunk.usage_metadata is not None - - # Should fallback to 0 input tokens and only report output tokens - assert msg_chunk.usage_metadata["input_tokens"] == 0 - assert msg_chunk.usage_metadata["output_tokens"] == 25 - assert msg_chunk.usage_metadata["total_tokens"] == 25 - - -def test_streaming_token_counting_cumulative_input_tokens() -> None: - """Test streaming handles cumulative input tokens from `message_delta` events. - - Validates that when Anthropic sends updated cumulative input tokens in - `message_delta` events (e.g., due to MCP tool calling), the implementation - prioritizes these updated counts over stored input usage. - - """ - # Mock `message_start` event with initial usage - message_start_event = SimpleNamespace( - type="message_start", - message=SimpleNamespace( - usage=Usage( - input_tokens=100, # Initial input tokens - output_tokens=1, - cache_creation_input_tokens=0, - cache_read_input_tokens=0, - ), - model="claude-opus-4-1-20250805", - ), - ) - - # Mock `message_delta` event with updated cumulative input tokens - # This happens when MCP tools are called mid-stream - message_delta_event = SimpleNamespace( - type="message_delta", - usage=MessageDeltaUsage( - output_tokens=50, - input_tokens=120, # Cumulative count increased due to tool calling - cache_creation_input_tokens=None, - cache_read_input_tokens=None, - ), - delta=SimpleNamespace( - stop_reason="end_turn", - stop_sequence=None, - ), - ) - - # Store input usage from `message_start` - _, _, stored_usage = _make_message_chunk_from_anthropic_event( - message_start_event, # type: ignore[arg-type] - stream_usage=True, - coerce_content_to_string=True, - stored_input_usage=None, - ) - - # Test `message_delta` with cumulative input tokens - msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( - message_delta_event, # type: ignore[arg-type] - stream_usage=True, - coerce_content_to_string=True, - stored_input_usage=stored_usage, - ) - - assert msg_chunk is not None - assert msg_chunk.usage_metadata is not None - - # Should use the cumulative input tokens from event (120) not stored (100) - assert msg_chunk.usage_metadata["input_tokens"] == 120 - assert msg_chunk.usage_metadata["output_tokens"] == 50 - assert msg_chunk.usage_metadata["total_tokens"] == 170 - - -def test_streaming_token_counting_cumulative_fallback() -> None: - """Test fallback handles cumulative input tokens from message_delta events. - - When no stored usage is available, validates that cumulative input tokens - from the message_delta event are still properly used instead of defaulting to 0. - """ - # Mock `message_delta` event with cumulative input tokens but no stored usage - message_delta_event = SimpleNamespace( - type="message_delta", - usage=MessageDeltaUsage( - output_tokens=30, - input_tokens=85, # Cumulative input tokens in the event - cache_creation_input_tokens=None, - cache_read_input_tokens=None, - ), - delta=SimpleNamespace( - stop_reason="end_turn", - stop_sequence=None, - ), - ) - - # Test `message_delta` without stored usage - should use event's input tokens - msg_chunk, _, _ = _make_message_chunk_from_anthropic_event( - message_delta_event, # type: ignore[arg-type] - stream_usage=True, - coerce_content_to_string=True, - stored_input_usage=None, # No stored usage - ) - - assert msg_chunk is not None - assert msg_chunk.usage_metadata is not None - - # Should use cumulative input tokens from event, not fallback to 0 - assert msg_chunk.usage_metadata["input_tokens"] == 85 # From event - assert msg_chunk.usage_metadata["output_tokens"] == 30 - assert msg_chunk.usage_metadata["total_tokens"] == 115