Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-22 02:45:49 +00:00
fix(anthropic): defer input token counting in streaming until completion (#32518)
Supersedes #32461.

Fixes incorrect input token reporting during streaming when tools are used. Previously, input tokens were counted at `message_start`, before tool execution, which led to inaccurate counts. Input tokens are now deferred until `message_delta` (completion), aligning with Anthropic's billing model and SDK expectations.

**Before fix:**

- Streaming with tools: input tokens = 0 ❌
- Non-streaming with tools: input tokens = 472 ✅

**After fix:**

- Streaming with tools: input tokens = 472 ✅
- Non-streaming with tools: input tokens = 472 ✅

The Anthropic SDK likewise applies input token updates in `message_delta` events:

```python
# https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py
if event.usage.input_tokens is not None:
    current_snapshot.usage.input_tokens = event.usage.input_tokens
```
This commit is contained in:
parent 2f32c444b8 · commit d3d23e2372
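From the caller's side, the intended effect is that the final streamed chunk carries the full usage. A minimal sketch of that behavior (the model name and prompt are placeholders, not from this PR, and an `ANTHROPIC_API_KEY` is assumed):

```python
# Minimal sketch of the observable behavior after this fix. Model name and
# prompt are hypothetical; an ANTHROPIC_API_KEY is assumed to be set.
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-sonnet-latest")

final_usage = None
for chunk in llm.stream("What is 2 + 2?"):
    if chunk.usage_metadata:
        # message_start now yields zeroed usage; the closing message_delta
        # chunk carries the full input + output counts.
        final_usage = chunk.usage_metadata

print(final_usage)
# With tools bound, input_tokens previously surfaced as 0 in streaming;
# it now matches the non-streaming count.
```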
Changes to `langchain_anthropic/chat_models.py`:

```diff
@@ -70,6 +70,20 @@ class AnthropicTool(TypedDict):
     cache_control: NotRequired[dict[str, str]]
 
 
+class _CombinedUsage(BaseModel):
+    """Combined usage model for deferred token counting in streaming.
+
+    This mimics the Anthropic Usage structure while combining stored input usage
+    with final output usage for accurate token reporting during streaming.
+    """
+
+    input_tokens: int = 0
+    output_tokens: int = 0
+    cache_creation_input_tokens: Optional[int] = None
+    cache_read_input_tokens: Optional[int] = None
+    cache_creation: Optional[dict[str, Any]] = None
+
+
 def _is_builtin_tool(tool: Any) -> bool:
     if not isinstance(tool, dict):
         return False
```
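`_CombinedUsage` only has to look like Anthropic's `Usage` to the downstream helper. A duck-typing sketch, with a stand-in `summarize` (hypothetical; the real code passes the object to `_create_usage_metadata`) and invented token counts:

```python
# Duck-typing sketch: the consumer reads only attributes that both
# anthropic.types.Usage and _CombinedUsage expose. `summarize` is a stand-in
# for _create_usage_metadata, not its real implementation.
from typing import Any, Optional

from pydantic import BaseModel


class _CombinedUsage(BaseModel):
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_input_tokens: Optional[int] = None
    cache_read_input_tokens: Optional[int] = None
    cache_creation: Optional[dict[str, Any]] = None


def summarize(usage: Any) -> dict[str, int]:
    # Reads the shared attribute surface, so either usage type works here.
    input_tokens = usage.input_tokens or 0
    output_tokens = usage.output_tokens or 0
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }


print(summarize(_CombinedUsage(input_tokens=472, output_tokens=50)))
# {'input_tokens': 472, 'output_tokens': 50, 'total_tokens': 522}
```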
```diff
@@ -1493,12 +1507,18 @@ class ChatAnthropic(BaseChatModel):
             and not _thinking_in_params(payload)
         )
         block_start_event = None
+        stored_input_usage = None
         for event in stream:
-            msg, block_start_event = _make_message_chunk_from_anthropic_event(
+            (
+                msg,
+                block_start_event,
+                stored_input_usage,
+            ) = _make_message_chunk_from_anthropic_event(
                 event,
                 stream_usage=stream_usage,
                 coerce_content_to_string=coerce_content_to_string,
                 block_start_event=block_start_event,
+                stored_input_usage=stored_input_usage,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
```
```diff
@@ -1529,12 +1549,18 @@ class ChatAnthropic(BaseChatModel):
             and not _thinking_in_params(payload)
         )
         block_start_event = None
+        stored_input_usage = None
         async for event in stream:
-            msg, block_start_event = _make_message_chunk_from_anthropic_event(
+            (
+                msg,
+                block_start_event,
+                stored_input_usage,
+            ) = _make_message_chunk_from_anthropic_event(
                 event,
                 stream_usage=stream_usage,
                 coerce_content_to_string=coerce_content_to_string,
                 block_start_event=block_start_event,
+                stored_input_usage=stored_input_usage,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
```
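Both the sync and async loops thread `stored_input_usage` through each iteration, so state captured at `message_start` is available when `message_delta` arrives. The pattern in isolation (a generic sketch with hypothetical event dicts, not the PR's code):

```python
# State-threading sketch: the loop variable `stored` carries whatever the
# handler returned on the previous iteration. Event shapes are hypothetical.
from typing import Iterable, Optional


def handle(event: dict, stored: Optional[int]) -> tuple[Optional[str], Optional[int]]:
    if event["type"] == "message_start":
        return None, event["input_tokens"]  # capture input count, emit nothing
    if event["type"] == "message_delta":
        total = (stored or 0) + event["output_tokens"]
        return f"total={total}", stored
    return None, stored


def run(events: Iterable[dict]) -> None:
    stored: Optional[int] = None
    for event in events:
        msg, stored = handle(event, stored)  # thread state through the loop
        if msg is not None:
            print(msg)


run([
    {"type": "message_start", "input_tokens": 472},
    {"type": "message_delta", "output_tokens": 50},
])
# prints: total=522
```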
```diff
@@ -2167,22 +2193,40 @@ def _make_message_chunk_from_anthropic_event(
     stream_usage: bool = True,
     coerce_content_to_string: bool,
     block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None,
-) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]:
-    """Convert Anthropic event to AIMessageChunk.
+    stored_input_usage: Optional[BaseModel] = None,
+) -> tuple[
+    Optional[AIMessageChunk],
+    Optional[anthropic.types.RawMessageStreamEvent],
+    Optional[BaseModel],
+]:
+    """Convert Anthropic event to ``AIMessageChunk``.
 
     Note that not all events will result in a message chunk. In these cases
     we return ``None``.
+
+    Args:
+        event: The Anthropic streaming event to convert.
+        stream_usage: Whether to include usage metadata in the chunk.
+        coerce_content_to_string: Whether to coerce content blocks to strings.
+        block_start_event: Previous content block start event for context.
+        stored_input_usage: Usage metadata from ``message_start`` event to be used
+            in ``message_delta`` event for accurate input token counts.
+
+    Returns:
+        Tuple of ``(message_chunk, block_start_event, stored_usage)``
+
     """
     message_chunk: Optional[AIMessageChunk] = None
+    updated_stored_usage = stored_input_usage
     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
-        usage_metadata = _create_usage_metadata(event.message.usage)
-        # We pick up a cumulative count of output_tokens at the end of the stream,
-        # so here we zero out to avoid double counting.
-        usage_metadata["total_tokens"] = (
-            usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
-        )
-        usage_metadata["output_tokens"] = 0
+        # Store input usage for later use in message_delta but don't emit tokens yet
+        updated_stored_usage = event.message.usage
+        usage_metadata = UsageMetadata(
+            input_tokens=0,
+            output_tokens=0,
+            total_tokens=0,
+        )
         if hasattr(event.message, "model"):
             response_metadata = {"model_name": event.message.model}
         else:
```
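The zeroed `UsageMetadata` at `message_start` matters because LangChain sums `usage_metadata` when message chunks are merged, so only one chunk per message should carry the real counts. A small sketch of that merging behavior (token values are made up):

```python
# Sketch (hypothetical token values) of why message_start emits zeroed usage:
# merging AIMessageChunks in langchain-core sums their usage_metadata.
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata

start = AIMessageChunk(
    content="",
    usage_metadata=UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0),
)
final = AIMessageChunk(
    content="",
    usage_metadata=UsageMetadata(input_tokens=472, output_tokens=50, total_tokens=522),
)

merged = start + final  # chunk aggregation sums per-chunk usage
assert merged.usage_metadata is not None
assert merged.usage_metadata["input_tokens"] == 472
assert merged.usage_metadata["output_tokens"] == 50
assert merged.usage_metadata["total_tokens"] == 522
# If both chunks carried real input counts, the merged total would double-count.
```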
```diff
@@ -2270,11 +2314,37 @@ def _make_message_chunk_from_anthropic_event(
             tool_call_chunks=tool_call_chunks,
         )
     elif event.type == "message_delta" and stream_usage:
-        usage_metadata = UsageMetadata(
-            input_tokens=0,
-            output_tokens=event.usage.output_tokens,
-            total_tokens=event.usage.output_tokens,
-        )
+        # Create usage metadata combining stored input usage with final output usage
+        #
+        # Per Anthropic docs: "The token counts shown in the usage field of the
+        # message_delta event are cumulative." Thus, when MCP tools are called
+        # mid-stream, `input_tokens` may be updated with a higher cumulative count.
+        # We prioritize `event.usage.input_tokens` when available to handle this case.
+        if stored_input_usage is not None:
+            # Create a combined usage object that mimics the Anthropic Usage structure
+            combined_usage = _CombinedUsage(
+                input_tokens=event.usage.input_tokens
+                or getattr(stored_input_usage, "input_tokens", 0),
+                output_tokens=event.usage.output_tokens,
+                cache_creation_input_tokens=getattr(
+                    stored_input_usage, "cache_creation_input_tokens", None
+                ),
+                cache_read_input_tokens=getattr(
+                    stored_input_usage, "cache_read_input_tokens", None
+                ),
+                cache_creation=getattr(stored_input_usage, "cache_creation", None)
+                if hasattr(stored_input_usage, "cache_creation")
+                else None,
+            )
+            usage_metadata = _create_usage_metadata(combined_usage)
+        else:
+            # Fallback to just output tokens if no stored usage
+            usage_metadata = UsageMetadata(
+                input_tokens=event.usage.input_tokens or 0,
+                output_tokens=event.usage.output_tokens,
+                total_tokens=(event.usage.input_tokens or 0)
+                + event.usage.output_tokens,
+            )
         message_chunk = AIMessageChunk(
             content="",
             usage_metadata=usage_metadata,
```
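The precedence between the delta event's cumulative count and the stored `message_start` count reduces to a single `or`. A standalone sketch with made-up values:

```python
# Precedence-rule sketch: a cumulative input count arriving in message_delta
# wins over the count stored at message_start; `or` treats None (and 0) as
# "not provided". Values here are hypothetical.
from typing import Optional


def effective_input_tokens(delta_input: Optional[int], stored_input: int) -> int:
    return delta_input or stored_input


assert effective_input_tokens(None, 100) == 100  # typical stream: delta has None
assert effective_input_tokens(120, 100) == 120   # MCP tools mid-stream: cumulative wins
```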
```diff
@@ -2286,7 +2356,7 @@ def _make_message_chunk_from_anthropic_event(
     else:
         pass
 
-    return message_chunk, block_start_event
+    return message_chunk, block_start_event, updated_stored_usage
 
 
 @deprecated(since="0.1.0", removal="1.0.0", alternative="ChatAnthropic")
```
Changes to the unit tests for `langchain_anthropic.chat_models`:

```diff
@@ -3,12 +3,13 @@
 from __future__ import annotations
 
 import os
+from types import SimpleNamespace
 from typing import Any, Callable, Literal, Optional, cast
 from unittest.mock import MagicMock, patch
 
 import anthropic
 import pytest
-from anthropic.types import Message, TextBlock, Usage
+from anthropic.types import Message, MessageDeltaUsage, TextBlock, Usage
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import RunnableBinding
 from langchain_core.tools import BaseTool
@@ -22,6 +23,7 @@ from langchain_anthropic.chat_models import (
     _create_usage_metadata,
     _format_image,
     _format_messages,
+    _make_message_chunk_from_anthropic_event,
     _merge_messages,
     convert_to_anthropic_tool,
 )
```
The remaining hunk (`@@ -1172,3 +1174,224 @@`) appends four new tests after `test_cache_control_kwarg`:

```python
def test_streaming_token_counting_deferred() -> None:
    """Test streaming defers input token counting until message completion.

    Validates that the streaming implementation correctly:
    1. Stores input tokens from `message_start` without emitting them immediately
    2. Combines stored input tokens with output tokens at `message_delta` completion
    3. Only emits complete token usage metadata when the message is finished

    This prevents the bug where tools would cause inaccurate token counts due to
    premature emission of input tokens before tool execution completed.
    """
    # Mock `message_start` event with usage
    message_start_event = SimpleNamespace(
        type="message_start",
        message=SimpleNamespace(
            usage=Usage(
                input_tokens=100,
                output_tokens=1,
                cache_creation_input_tokens=0,
                cache_read_input_tokens=0,
            ),
            model="claude-opus-4-1-20250805",
        ),
    )

    # Mock `message_delta` event with final output tokens
    message_delta_event = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=50,
            input_tokens=None,  # This is None in real delta events
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    # Test `message_start` event - should store input tokens but not emit them
    msg_chunk, _, stored_usage = _make_message_chunk_from_anthropic_event(
        message_start_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,
    )

    assert msg_chunk is not None
    assert msg_chunk.usage_metadata is not None

    # Input tokens should be 0 at message_start (deferred)
    assert msg_chunk.usage_metadata["input_tokens"] == 0
    assert msg_chunk.usage_metadata["output_tokens"] == 0
    assert msg_chunk.usage_metadata["total_tokens"] == 0

    # Usage should be stored
    assert stored_usage is not None
    assert getattr(stored_usage, "input_tokens", 0) == 100

    # Test `message_delta` - combine stored input with delta output tokens
    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
        message_delta_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=stored_usage,
    )

    assert msg_chunk is not None
    assert msg_chunk.usage_metadata is not None

    # Should now have the complete usage metadata
    assert msg_chunk.usage_metadata["input_tokens"] == 100  # From stored usage
    assert msg_chunk.usage_metadata["output_tokens"] == 50  # From delta event
    assert msg_chunk.usage_metadata["total_tokens"] == 150

    # Verify response metadata is properly set
    assert "stop_reason" in msg_chunk.response_metadata
    assert msg_chunk.response_metadata["stop_reason"] == "end_turn"


def test_streaming_token_counting_fallback() -> None:
    """Test streaming token counting gracefully handles missing stored usage.

    Validates that when no stored input usage is available (edge case scenario),
    the streaming implementation safely falls back to reporting only output tokens
    rather than failing or returning invalid token counts.
    """
    # Mock message_delta event without stored input usage
    message_delta_event = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=25,
            input_tokens=None,
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    # Test message_delta without stored usage - should fallback gracefully
    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
        message_delta_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,  # No stored usage
    )

    assert msg_chunk is not None
    assert msg_chunk.usage_metadata is not None

    # Should fallback to 0 input tokens and only report output tokens
    assert msg_chunk.usage_metadata["input_tokens"] == 0
    assert msg_chunk.usage_metadata["output_tokens"] == 25
    assert msg_chunk.usage_metadata["total_tokens"] == 25


def test_streaming_token_counting_cumulative_input_tokens() -> None:
    """Test streaming handles cumulative input tokens from `message_delta` events.

    Validates that when Anthropic sends updated cumulative input tokens in
    `message_delta` events (e.g., due to MCP tool calling), the implementation
    prioritizes these updated counts over stored input usage.
    """
    # Mock `message_start` event with initial usage
    message_start_event = SimpleNamespace(
        type="message_start",
        message=SimpleNamespace(
            usage=Usage(
                input_tokens=100,  # Initial input tokens
                output_tokens=1,
                cache_creation_input_tokens=0,
                cache_read_input_tokens=0,
            ),
            model="claude-opus-4-1-20250805",
        ),
    )

    # Mock `message_delta` event with updated cumulative input tokens
    # This happens when MCP tools are called mid-stream
    message_delta_event = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=50,
            input_tokens=120,  # Cumulative count increased due to tool calling
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    # Store input usage from `message_start`
    _, _, stored_usage = _make_message_chunk_from_anthropic_event(
        message_start_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,
    )

    # Test `message_delta` with cumulative input tokens
    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
        message_delta_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=stored_usage,
    )

    assert msg_chunk is not None
    assert msg_chunk.usage_metadata is not None

    # Should use the cumulative input tokens from event (120) not stored (100)
    assert msg_chunk.usage_metadata["input_tokens"] == 120
    assert msg_chunk.usage_metadata["output_tokens"] == 50
    assert msg_chunk.usage_metadata["total_tokens"] == 170


def test_streaming_token_counting_cumulative_fallback() -> None:
    """Test fallback handles cumulative input tokens from message_delta events.

    When no stored usage is available, validates that cumulative input tokens
    from the message_delta event are still properly used instead of defaulting to 0.
    """
    # Mock `message_delta` event with cumulative input tokens but no stored usage
    message_delta_event = SimpleNamespace(
        type="message_delta",
        usage=MessageDeltaUsage(
            output_tokens=30,
            input_tokens=85,  # Cumulative input tokens in the event
            cache_creation_input_tokens=None,
            cache_read_input_tokens=None,
        ),
        delta=SimpleNamespace(
            stop_reason="end_turn",
            stop_sequence=None,
        ),
    )

    # Test `message_delta` without stored usage - should use event's input tokens
    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
        message_delta_event,  # type: ignore[arg-type]
        stream_usage=True,
        coerce_content_to_string=True,
        stored_input_usage=None,  # No stored usage
    )

    assert msg_chunk is not None
    assert msg_chunk.usage_metadata is not None

    # Should use cumulative input tokens from event, not fallback to 0
    assert msg_chunk.usage_metadata["input_tokens"] == 85  # From event
    assert msg_chunk.usage_metadata["output_tokens"] == 30
    assert msg_chunk.usage_metadata["total_tokens"] == 115
```
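Assuming the repository's usual pytest setup, all four tests share a name prefix and can be selected together with `pytest -k streaming_token_counting`.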