fix(anthropic): correct input_token count for streaming (#32591)

* Create usage metadata on
[`message_delta`](https://docs.anthropic.com/en/docs/build-with-claude/streaming#event-types)
instead of at the beginning. Consequently, token counts are not included
during streaming but instead at the end. This allows for accurate
reporting of server-side tool usage (important for billing)
* Add some clarifying comments
* Fix some outstanding Pylance warnings
* Remove unnecessary `text` popping in thinking blocks
* Also now correctly reports `input_cache_read`/`input_cache_creation`
as a result
This commit is contained in:
Mason Daugherty
2025-08-18 13:51:47 -04:00
committed by GitHub
parent 8042b04da6
commit 8d0fb2d04b
7 changed files with 163 additions and 46 deletions

View File

@@ -2192,47 +2192,65 @@ def _make_message_chunk_from_anthropic_event(
coerce_content_to_string: bool,
block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None,
) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]:
"""Convert Anthropic event to AIMessageChunk.
"""Convert Anthropic streaming event to `AIMessageChunk`.
Args:
event: Raw streaming event from Anthropic SDK
stream_usage: Whether to include usage metadata in the output chunks.
coerce_content_to_string: Whether to convert structured content to plain
text strings. When True, only text content is preserved; when False,
structured content like tool calls and citations are maintained.
block_start_event: Previous content block start event, used for tracking
tool use blocks and maintaining context across related events.
Returns:
Tuple containing:
- AIMessageChunk: Converted message chunk with appropriate content and
metadata, or None if the event doesn't produce a chunk
- RawMessageStreamEvent: Updated `block_start_event` for tracking content
blocks across sequential events, or None if not applicable
Note:
Not all Anthropic events result in message chunks. Events like internal
state changes return None for the message chunk while potentially
updating the `block_start_event` for context tracking.
Note that not all events will result in a message chunk. In these cases
we return ``None``.
"""
message_chunk: Optional[AIMessageChunk] = None
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
# Reference: Anthropic SDK streaming implementation
# https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
usage_metadata = _create_usage_metadata(event.message.usage)
# We pick up a cumulative count of output_tokens at the end of the stream,
# so here we zero out to avoid double counting.
usage_metadata["total_tokens"] = (
usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
)
usage_metadata["output_tokens"] = 0
# Capture model name, but don't include usage_metadata yet
# as it will be properly reported in message_delta with complete info
if hasattr(event.message, "model"):
response_metadata = {"model_name": event.message.model}
else:
response_metadata = {}
message_chunk = AIMessageChunk(
content="" if coerce_content_to_string else [],
usage_metadata=usage_metadata,
response_metadata=response_metadata,
)
elif (
event.type == "content_block_start"
and event.content_block is not None
and event.content_block.type
in (
"tool_use",
"code_execution_tool_result",
"tool_use", # Standard tool usage
"code_execution_tool_result", # Built-in code execution results
"document",
"redacted_thinking",
"mcp_tool_use",
"mcp_tool_result",
"server_tool_use",
"web_search_tool_result",
"server_tool_use", # Server-side tool usage
"web_search_tool_result", # Built-in web search results
)
):
if coerce_content_to_string:
warnings.warn("Received unexpected tool content block.", stacklevel=2)
content_block = event.content_block.model_dump()
content_block["index"] = event.index
if event.content_block.type == "tool_use":
@@ -2250,35 +2268,47 @@ def _make_message_chunk_from_anthropic_event(
tool_call_chunks=tool_call_chunks,
)
block_start_event = event
# Process incremental content updates
elif event.type == "content_block_delta":
# Text and citation deltas (incremental text content)
if event.delta.type in ("text_delta", "citations_delta"):
if coerce_content_to_string and hasattr(event.delta, "text"):
text = event.delta.text
text = getattr(event.delta, "text", "")
message_chunk = AIMessageChunk(content=text)
else:
content_block = event.delta.model_dump()
content_block["index"] = event.index
# All citation deltas are part of a text block
content_block["type"] = "text"
if "citation" in content_block:
# Assign citations to a list if present
content_block["citations"] = [content_block.pop("citation")]
message_chunk = AIMessageChunk(content=[content_block])
# Reasoning
elif (
event.delta.type == "thinking_delta"
or event.delta.type == "signature_delta"
):
content_block = event.delta.model_dump()
if "text" in content_block and content_block["text"] is None:
content_block.pop("text")
content_block["index"] = event.index
content_block["type"] = "thinking"
message_chunk = AIMessageChunk(content=[content_block])
# Tool input JSON (streaming tool arguments)
elif event.delta.type == "input_json_delta":
content_block = event.delta.model_dump()
content_block["index"] = event.index
start_event_block = (
getattr(block_start_event, "content_block", None)
if block_start_event
else None
)
if (
(block_start_event is not None)
and hasattr(block_start_event, "content_block")
and (block_start_event.content_block.type == "tool_use")
start_event_block is not None
and getattr(start_event_block, "type", None) == "tool_use"
):
tool_call_chunk = create_tool_call_chunk(
index=event.index,
@@ -2293,12 +2323,10 @@ def _make_message_chunk_from_anthropic_event(
content=[content_block],
tool_call_chunks=tool_call_chunks,
)
# Process final usage metadata and completion info
elif event.type == "message_delta" and stream_usage:
usage_metadata = UsageMetadata(
input_tokens=0,
output_tokens=event.usage.output_tokens,
total_tokens=event.usage.output_tokens,
)
usage_metadata = _create_usage_metadata(event.usage)
message_chunk = AIMessageChunk(
content="",
usage_metadata=usage_metadata,
@@ -2307,6 +2335,8 @@ def _make_message_chunk_from_anthropic_event(
"stop_sequence": event.delta.stop_sequence,
},
)
# Unhandled event types (e.g., `content_block_stop`, `ping` events)
# https://docs.anthropic.com/en/docs/build-with-claude/streaming#other-events
else:
pass
@@ -2319,26 +2349,38 @@ class ChatAnthropicMessages(ChatAnthropic):
def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
"""Create LangChain `UsageMetadata` from Anthropic `Usage` data.
Note: Anthropic's `input_tokens` excludes cached tokens, so we manually add
`cache_read` and `cache_creation` tokens to get the true total.
"""
input_token_details: dict = {
"cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}
# Add (beta) cache TTL information if available
# Add cache TTL information if provided (5-minute and 1-hour ephemeral cache)
cache_creation = getattr(anthropic_usage, "cache_creation", None)
cache_creation_keys = ("ephemeral_1h_input_tokens", "ephemeral_5m_input_tokens")
# Currently just copying over the 5m and 1h keys, but if more are added in the
# future we'll need to expand this tuple
cache_creation_keys = ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens")
if cache_creation:
if isinstance(cache_creation, BaseModel):
cache_creation = cache_creation.model_dump()
for k in cache_creation_keys:
input_token_details[k] = cache_creation.get(k)
# Anthropic input_tokens exclude cached token counts.
# Calculate total input tokens: Anthropic's `input_tokens` excludes cached tokens,
# so we need to add them back to get the true total input token count
input_tokens = (
(getattr(anthropic_usage, "input_tokens", 0) or 0)
+ (input_token_details["cache_read"] or 0)
+ (input_token_details["cache_creation"] or 0)
(getattr(anthropic_usage, "input_tokens", 0) or 0) # Base input tokens
+ (input_token_details["cache_read"] or 0) # Tokens read from cache
+ (input_token_details["cache_creation"] or 0) # Tokens used to create cache
)
output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,