revert(anthropic): streaming token counting to defer input tokens until completion (#32587)
Reverts langchain-ai/langchain#32518
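This restores the pre-#32518 accounting: the `message_start` chunk reports input tokens immediately (with output tokens zeroed, because `message_delta` later carries a cumulative output count), and the `message_delta` chunk reports only output tokens. Below is a minimal sketch of the restored arithmetic; the token values are illustrative, not taken from this commit.

# Sketch of the restored per-event usage accounting (values illustrative).
def start_usage(input_tokens: int, output_tokens: int) -> dict[str, int]:
    # message_start: emit input tokens now; zero out output tokens because
    # a cumulative output count arrives later in message_delta.
    total_tokens = input_tokens + output_tokens
    return {
        "input_tokens": input_tokens,
        "output_tokens": 0,
        "total_tokens": total_tokens - output_tokens,
    }


def delta_usage(output_tokens: int) -> dict[str, int]:
    # message_delta: report only the cumulative output count.
    return {
        "input_tokens": 0,
        "output_tokens": output_tokens,
        "total_tokens": output_tokens,
    }


start = start_usage(input_tokens=100, output_tokens=1)  # 100 / 0 / 100
delta = delta_usage(output_tokens=50)                   # 0 / 50 / 50
merged = {key: start[key] + delta[key] for key in start}
assert merged == {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}

Summed across the stream, the chunks still yield complete totals, so the deferred bookkeeping removed below (`_CombinedUsage`, `stored_input_usage`) is not needed for the correctness of the final count.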
This commit is contained in:
parent b8cdbc4eca
commit fd891ee3d4
@@ -70,20 +70,6 @@ class AnthropicTool(TypedDict):
     cache_control: NotRequired[dict[str, str]]
 
 
-class _CombinedUsage(BaseModel):
-    """Combined usage model for deferred token counting in streaming.
-
-    This mimics the Anthropic Usage structure while combining stored input usage
-    with final output usage for accurate token reporting during streaming.
-    """
-
-    input_tokens: int = 0
-    output_tokens: int = 0
-    cache_creation_input_tokens: Optional[int] = None
-    cache_read_input_tokens: Optional[int] = None
-    cache_creation: Optional[dict[str, Any]] = None
-
-
 def _is_builtin_tool(tool: Any) -> bool:
     if not isinstance(tool, dict):
         return False
@@ -1522,18 +1508,12 @@ class ChatAnthropic(BaseChatModel):
             and not _thinking_in_params(payload)
         )
         block_start_event = None
-        stored_input_usage = None
         for event in stream:
-            (
-                msg,
-                block_start_event,
-                stored_input_usage,
-            ) = _make_message_chunk_from_anthropic_event(
+            msg, block_start_event = _make_message_chunk_from_anthropic_event(
                 event,
                 stream_usage=stream_usage,
                 coerce_content_to_string=coerce_content_to_string,
                 block_start_event=block_start_event,
-                stored_input_usage=stored_input_usage,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
@@ -1564,18 +1544,12 @@ class ChatAnthropic(BaseChatModel):
             and not _thinking_in_params(payload)
         )
         block_start_event = None
-        stored_input_usage = None
         async for event in stream:
-            (
-                msg,
-                block_start_event,
-                stored_input_usage,
-            ) = _make_message_chunk_from_anthropic_event(
+            msg, block_start_event = _make_message_chunk_from_anthropic_event(
                 event,
                 stream_usage=stream_usage,
                 coerce_content_to_string=coerce_content_to_string,
                 block_start_event=block_start_event,
-                stored_input_usage=stored_input_usage,
             )
             if msg is not None:
                 chunk = ChatGenerationChunk(message=msg)
@@ -2208,40 +2182,22 @@ def _make_message_chunk_from_anthropic_event(
     stream_usage: bool = True,
     coerce_content_to_string: bool,
     block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None,
-    stored_input_usage: Optional[BaseModel] = None,
-) -> tuple[
-    Optional[AIMessageChunk],
-    Optional[anthropic.types.RawMessageStreamEvent],
-    Optional[BaseModel],
-]:
-    """Convert Anthropic event to ``AIMessageChunk``.
+) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]:
+    """Convert Anthropic event to AIMessageChunk.
 
     Note that not all events will result in a message chunk. In these cases
     we return ``None``.
-
-    Args:
-        event: The Anthropic streaming event to convert.
-        stream_usage: Whether to include usage metadata in the chunk.
-        coerce_content_to_string: Whether to coerce content blocks to strings.
-        block_start_event: Previous content block start event for context.
-        stored_input_usage: Usage metadata from ``message_start`` event to be used
-            in ``message_delta`` event for accurate input token counts.
-
-    Returns:
-        Tuple of ``(message_chunk, block_start_event, stored_usage)``
-
     """
     message_chunk: Optional[AIMessageChunk] = None
-    updated_stored_usage = stored_input_usage
     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
-        # Store input usage for later use in message_delta but don't emit tokens yet
-        updated_stored_usage = event.message.usage
-        usage_metadata = UsageMetadata(
-            input_tokens=0,
-            output_tokens=0,
-            total_tokens=0,
+        usage_metadata = _create_usage_metadata(event.message.usage)
+        # We pick up a cumulative count of output_tokens at the end of the stream,
+        # so here we zero out to avoid double counting.
+        usage_metadata["total_tokens"] = (
+            usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
         )
+        usage_metadata["output_tokens"] = 0
         if hasattr(event.message, "model"):
            response_metadata = {"model_name": event.message.model}
         else:
@@ -2329,37 +2285,11 @@ def _make_message_chunk_from_anthropic_event(
             tool_call_chunks=tool_call_chunks,
         )
     elif event.type == "message_delta" and stream_usage:
-        # Create usage metadata combining stored input usage with final output usage
-        #
-        # Per Anthropic docs: "The token counts shown in the usage field of the
-        # message_delta event are cumulative." Thus, when MCP tools are called
-        # mid-stream, `input_tokens` may be updated with a higher cumulative count.
-        # We prioritize `event.usage.input_tokens` when available to handle this case.
-        if stored_input_usage is not None:
-            # Create a combined usage object that mimics the Anthropic Usage structure
-            combined_usage = _CombinedUsage(
-                input_tokens=event.usage.input_tokens
-                or getattr(stored_input_usage, "input_tokens", 0),
-                output_tokens=event.usage.output_tokens,
-                cache_creation_input_tokens=getattr(
-                    stored_input_usage, "cache_creation_input_tokens", None
-                ),
-                cache_read_input_tokens=getattr(
-                    stored_input_usage, "cache_read_input_tokens", None
-                ),
-                cache_creation=getattr(stored_input_usage, "cache_creation", None)
-                if hasattr(stored_input_usage, "cache_creation")
-                else None,
-            )
-            usage_metadata = _create_usage_metadata(combined_usage)
-        else:
-            # Fallback to just output tokens if no stored usage
-            usage_metadata = UsageMetadata(
-                input_tokens=event.usage.input_tokens or 0,
-                output_tokens=event.usage.output_tokens,
-                total_tokens=(event.usage.input_tokens or 0)
-                + event.usage.output_tokens,
-            )
+        usage_metadata = UsageMetadata(
+            input_tokens=0,
+            output_tokens=event.usage.output_tokens,
+            total_tokens=event.usage.output_tokens,
+        )
         message_chunk = AIMessageChunk(
             content="",
             usage_metadata=usage_metadata,
@@ -2371,7 +2301,7 @@ def _make_message_chunk_from_anthropic_event(
     else:
         pass
 
-    return message_chunk, block_start_event, updated_stored_usage
+    return message_chunk, block_start_event
 
 
 @deprecated(since="0.1.0", removal="1.0.0", alternative="ChatAnthropic")
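Since the restored accounting splits usage between the `message_start` and `message_delta` chunks, callers that accumulate streamed chunks still end up with complete totals, because `AIMessageChunk` addition merges `usage_metadata` by summing counts. A hedged end-to-end sketch; the model name is illustrative and an `ANTHROPIC_API_KEY` is assumed to be set:

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-haiku-latest")  # model name illustrative

final = None
for chunk in llm.stream("Say hello"):
    # Chunk addition sums usage_metadata: input tokens from the
    # message_start chunk plus output tokens from the message_delta chunk.
    final = chunk if final is None else final + chunk

print(final.usage_metadata if final is not None else None)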
@@ -3,13 +3,12 @@
 from __future__ import annotations
 
 import os
-from types import SimpleNamespace
 from typing import Any, Callable, Literal, Optional, cast
 from unittest.mock import MagicMock, patch
 
 import anthropic
 import pytest
-from anthropic.types import Message, MessageDeltaUsage, TextBlock, Usage
+from anthropic.types import Message, TextBlock, Usage
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import RunnableBinding
 from langchain_core.tools import BaseTool
@@ -23,7 +22,6 @@ from langchain_anthropic.chat_models import (
     _create_usage_metadata,
     _format_image,
     _format_messages,
-    _make_message_chunk_from_anthropic_event,
     _merge_messages,
     convert_to_anthropic_tool,
 )
@@ -1215,224 +1213,3 @@ def test_cache_control_kwarg() -> None:
         ],
     },
 ]
-
-
-def test_streaming_token_counting_deferred() -> None:
-    """Test streaming defers input token counting until message completion.
-
-    Validates that the streaming implementation correctly:
-    1. Stores input tokens from `message_start` without emitting them immediately
-    2. Combines stored input tokens with output tokens at `message_delta` completion
-    3. Only emits complete token usage metadata when the message is finished
-
-    This prevents the bug where tools would cause inaccurate token counts due to
-    premature emission of input tokens before tool execution completed.
-    """
-    # Mock `message_start` event with usage
-    message_start_event = SimpleNamespace(
-        type="message_start",
-        message=SimpleNamespace(
-            usage=Usage(
-                input_tokens=100,
-                output_tokens=1,
-                cache_creation_input_tokens=0,
-                cache_read_input_tokens=0,
-            ),
-            model="claude-opus-4-1-20250805",
-        ),
-    )
-
-    # Mock `message_delta` event with final output tokens
-    message_delta_event = SimpleNamespace(
-        type="message_delta",
-        usage=MessageDeltaUsage(
-            output_tokens=50,
-            input_tokens=None,  # This is None in real delta events
-            cache_creation_input_tokens=None,
-            cache_read_input_tokens=None,
-        ),
-        delta=SimpleNamespace(
-            stop_reason="end_turn",
-            stop_sequence=None,
-        ),
-    )
-
-    # Test `message_start` event - should store input tokens but not emit them
-    msg_chunk, _, stored_usage = _make_message_chunk_from_anthropic_event(
-        message_start_event,  # type: ignore[arg-type]
-        stream_usage=True,
-        coerce_content_to_string=True,
-        stored_input_usage=None,
-    )
-
-    assert msg_chunk is not None
-    assert msg_chunk.usage_metadata is not None
-
-    # Input tokens should be 0 at message_start (deferred)
-    assert msg_chunk.usage_metadata["input_tokens"] == 0
-    assert msg_chunk.usage_metadata["output_tokens"] == 0
-    assert msg_chunk.usage_metadata["total_tokens"] == 0
-
-    # Usage should be stored
-    assert stored_usage is not None
-    assert getattr(stored_usage, "input_tokens", 0) == 100
-
-    # Test `message_delta` - combine stored input with delta output tokens
-    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
-        message_delta_event,  # type: ignore[arg-type]
-        stream_usage=True,
-        coerce_content_to_string=True,
-        stored_input_usage=stored_usage,
-    )
-
-    assert msg_chunk is not None
-    assert msg_chunk.usage_metadata is not None
-
-    # Should now have the complete usage metadata
-    assert msg_chunk.usage_metadata["input_tokens"] == 100  # From stored usage
-    assert msg_chunk.usage_metadata["output_tokens"] == 50  # From delta event
-    assert msg_chunk.usage_metadata["total_tokens"] == 150
-
-    # Verify response metadata is properly set
-    assert "stop_reason" in msg_chunk.response_metadata
-    assert msg_chunk.response_metadata["stop_reason"] == "end_turn"
-
-
-def test_streaming_token_counting_fallback() -> None:
-    """Test streaming token counting gracefully handles missing stored usage.
-
-    Validates that when no stored input usage is available (edge case scenario),
-    the streaming implementation safely falls back to reporting only output tokens
-    rather than failing or returning invalid token counts.
-    """
-    # Mock message_delta event without stored input usage
-    message_delta_event = SimpleNamespace(
-        type="message_delta",
-        usage=MessageDeltaUsage(
-            output_tokens=25,
-            input_tokens=None,
-            cache_creation_input_tokens=None,
-            cache_read_input_tokens=None,
-        ),
-        delta=SimpleNamespace(
-            stop_reason="end_turn",
-            stop_sequence=None,
-        ),
-    )
-
-    # Test message_delta without stored usage - should fallback gracefully
-    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
-        message_delta_event,  # type: ignore[arg-type]
-        stream_usage=True,
-        coerce_content_to_string=True,
-        stored_input_usage=None,  # No stored usage
-    )
-
-    assert msg_chunk is not None
-    assert msg_chunk.usage_metadata is not None
-
-    # Should fallback to 0 input tokens and only report output tokens
-    assert msg_chunk.usage_metadata["input_tokens"] == 0
-    assert msg_chunk.usage_metadata["output_tokens"] == 25
-    assert msg_chunk.usage_metadata["total_tokens"] == 25
-
-
-def test_streaming_token_counting_cumulative_input_tokens() -> None:
-    """Test streaming handles cumulative input tokens from `message_delta` events.
-
-    Validates that when Anthropic sends updated cumulative input tokens in
-    `message_delta` events (e.g., due to MCP tool calling), the implementation
-    prioritizes these updated counts over stored input usage.
-
-    """
-    # Mock `message_start` event with initial usage
-    message_start_event = SimpleNamespace(
-        type="message_start",
-        message=SimpleNamespace(
-            usage=Usage(
-                input_tokens=100,  # Initial input tokens
-                output_tokens=1,
-                cache_creation_input_tokens=0,
-                cache_read_input_tokens=0,
-            ),
-            model="claude-opus-4-1-20250805",
-        ),
-    )
-
-    # Mock `message_delta` event with updated cumulative input tokens
-    # This happens when MCP tools are called mid-stream
-    message_delta_event = SimpleNamespace(
-        type="message_delta",
-        usage=MessageDeltaUsage(
-            output_tokens=50,
-            input_tokens=120,  # Cumulative count increased due to tool calling
-            cache_creation_input_tokens=None,
-            cache_read_input_tokens=None,
-        ),
-        delta=SimpleNamespace(
-            stop_reason="end_turn",
-            stop_sequence=None,
-        ),
-    )
-
-    # Store input usage from `message_start`
-    _, _, stored_usage = _make_message_chunk_from_anthropic_event(
-        message_start_event,  # type: ignore[arg-type]
-        stream_usage=True,
-        coerce_content_to_string=True,
-        stored_input_usage=None,
-    )
-
-    # Test `message_delta` with cumulative input tokens
-    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
-        message_delta_event,  # type: ignore[arg-type]
-        stream_usage=True,
-        coerce_content_to_string=True,
-        stored_input_usage=stored_usage,
-    )
-
-    assert msg_chunk is not None
-    assert msg_chunk.usage_metadata is not None
-
-    # Should use the cumulative input tokens from event (120) not stored (100)
-    assert msg_chunk.usage_metadata["input_tokens"] == 120
-    assert msg_chunk.usage_metadata["output_tokens"] == 50
-    assert msg_chunk.usage_metadata["total_tokens"] == 170
-
-
-def test_streaming_token_counting_cumulative_fallback() -> None:
-    """Test fallback handles cumulative input tokens from message_delta events.
-
-    When no stored usage is available, validates that cumulative input tokens
-    from the message_delta event are still properly used instead of defaulting to 0.
-    """
-    # Mock `message_delta` event with cumulative input tokens but no stored usage
-    message_delta_event = SimpleNamespace(
-        type="message_delta",
-        usage=MessageDeltaUsage(
-            output_tokens=30,
-            input_tokens=85,  # Cumulative input tokens in the event
-            cache_creation_input_tokens=None,
-            cache_read_input_tokens=None,
-        ),
-        delta=SimpleNamespace(
-            stop_reason="end_turn",
-            stop_sequence=None,
-        ),
-    )
-
-    # Test `message_delta` without stored usage - should use event's input tokens
-    msg_chunk, _, _ = _make_message_chunk_from_anthropic_event(
-        message_delta_event,  # type: ignore[arg-type]
-        stream_usage=True,
-        coerce_content_to_string=True,
-        stored_input_usage=None,  # No stored usage
-    )
-
-    assert msg_chunk is not None
-    assert msg_chunk.usage_metadata is not None
-
-    # Should use cumulative input tokens from event, not fallback to 0
-    assert msg_chunk.usage_metadata["input_tokens"] == 85  # From event
-    assert msg_chunk.usage_metadata["output_tokens"] == 30
-    assert msg_chunk.usage_metadata["total_tokens"] == 115
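With the revert applied, `_make_message_chunk_from_anthropic_event` returns a 2-tuple again and reports input tokens at `message_start`. A sketch of how that could be exercised, reusing the mocking style of the tests removed above; the expected values follow the restored arithmetic and do not come from an existing test:

from types import SimpleNamespace

from anthropic.types import Usage

from langchain_anthropic.chat_models import _make_message_chunk_from_anthropic_event

message_start_event = SimpleNamespace(
    type="message_start",
    message=SimpleNamespace(
        usage=Usage(
            input_tokens=100,
            output_tokens=1,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=0,
        ),
        model="claude-opus-4-1-20250805",
    ),
)

msg, _block_start = _make_message_chunk_from_anthropic_event(
    message_start_event,  # type: ignore[arg-type]
    stream_usage=True,
    coerce_content_to_string=True,
)

assert msg is not None
assert msg.usage_metadata is not None
# Input tokens are reported immediately; output tokens stay 0 here and
# arrive with the final message_delta event.
assert msg.usage_metadata["input_tokens"] == 100
assert msg.usage_metadata["output_tokens"] == 0
assert msg.usage_metadata["total_tokens"] == 100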