fix(anthropic): correct input_token count for streaming (#32591)

* Create usage metadata on the
[`message_delta`](https://docs.anthropic.com/en/docs/build-with-claude/streaming#event-types)
event instead of at the beginning of the stream. Consequently, token
counts are no longer emitted incrementally during streaming but once at
the end, which allows accurate reporting of server-side tool usage
(important for billing)
* Add some clarifying comments
* Fix some outstanding Pylance warnings
* Remove unnecessary `text` popping in thinking blocks
* As a result, `input_cache_read`/`input_cache_creation` are now also
reported correctly
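A minimal sketch of the consumer-visible effect (the model name and prompt
are illustrative, and `stream_usage=True` is already the default):
aggregated token counts now arrive with the final chunk of the stream
rather than the first.

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-haiku-latest", stream_usage=True)

# Sum the chunks; usage metadata is now attached to the chunk built from
# the final `message_delta` event, not the initial `message_start` event
aggregate = None
for chunk in llm.stream("Say hello!"):
    aggregate = chunk if aggregate is None else aggregate + chunk

# `input_tokens` includes cache read/creation tokens; the breakdown appears
# under `input_token_details` as `cache_read` / `cache_creation`
print(aggregate.usage_metadata)
```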
Mason Daugherty authored 2025-08-18 13:51:47 -04:00, committed by GitHub
parent 8042b04da6
commit 8d0fb2d04b
7 changed files with 163 additions and 46 deletions

View File

@@ -2192,47 +2192,65 @@ def _make_message_chunk_from_anthropic_event(
coerce_content_to_string: bool,
block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None,
) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]:
"""Convert Anthropic event to AIMessageChunk.
"""Convert Anthropic streaming event to `AIMessageChunk`.
Args:
event: Raw streaming event from Anthropic SDK
stream_usage: Whether to include usage metadata in the output chunks.
coerce_content_to_string: Whether to convert structured content to plain
text strings. When True, only text content is preserved; when False,
structured content like tool calls and citations are maintained.
block_start_event: Previous content block start event, used for tracking
tool use blocks and maintaining context across related events.
Returns:
Tuple containing:
- AIMessageChunk: Converted message chunk with appropriate content and
metadata, or None if the event doesn't produce a chunk
- RawMessageStreamEvent: Updated `block_start_event` for tracking content
blocks across sequential events, or None if not applicable
Note:
Not all Anthropic events result in message chunks. Events like internal
state changes return None for the message chunk while potentially
updating the `block_start_event` for context tracking.
Note that not all events will result in a message chunk. In these cases
we return ``None``.
"""
message_chunk: Optional[AIMessageChunk] = None
-# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
+# Reference: Anthropic SDK streaming implementation
+# https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
-usage_metadata = _create_usage_metadata(event.message.usage)
-# We pick up a cumulative count of output_tokens at the end of the stream,
-# so here we zero out to avoid double counting.
-usage_metadata["total_tokens"] = (
-usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
-)
-usage_metadata["output_tokens"] = 0
+# Capture model name, but don't include usage_metadata yet
+# as it will be properly reported in message_delta with complete info
if hasattr(event.message, "model"):
response_metadata = {"model_name": event.message.model}
else:
response_metadata = {}
message_chunk = AIMessageChunk(
content="" if coerce_content_to_string else [],
-usage_metadata=usage_metadata,
response_metadata=response_metadata,
)
elif (
event.type == "content_block_start"
and event.content_block is not None
and event.content_block.type
in (
"tool_use",
"code_execution_tool_result",
"tool_use", # Standard tool usage
"code_execution_tool_result", # Built-in code execution results
"document",
"redacted_thinking",
"mcp_tool_use",
"mcp_tool_result",
"server_tool_use",
"web_search_tool_result",
"server_tool_use", # Server-side tool usage
"web_search_tool_result", # Built-in web search results
)
):
if coerce_content_to_string:
warnings.warn("Received unexpected tool content block.", stacklevel=2)
content_block = event.content_block.model_dump()
content_block["index"] = event.index
if event.content_block.type == "tool_use":
@@ -2250,35 +2268,47 @@ def _make_message_chunk_from_anthropic_event(
tool_call_chunks=tool_call_chunks,
)
block_start_event = event
+# Process incremental content updates
elif event.type == "content_block_delta":
+# Text and citation deltas (incremental text content)
if event.delta.type in ("text_delta", "citations_delta"):
if coerce_content_to_string and hasattr(event.delta, "text"):
-text = event.delta.text
+text = getattr(event.delta, "text", "")
message_chunk = AIMessageChunk(content=text)
else:
content_block = event.delta.model_dump()
content_block["index"] = event.index
+# All citation deltas are part of a text block
content_block["type"] = "text"
if "citation" in content_block:
+# Assign citations to a list if present
content_block["citations"] = [content_block.pop("citation")]
message_chunk = AIMessageChunk(content=[content_block])
+# Reasoning
elif (
event.delta.type == "thinking_delta"
or event.delta.type == "signature_delta"
):
content_block = event.delta.model_dump()
if "text" in content_block and content_block["text"] is None:
content_block.pop("text")
content_block["index"] = event.index
content_block["type"] = "thinking"
message_chunk = AIMessageChunk(content=[content_block])
+# Tool input JSON (streaming tool arguments)
elif event.delta.type == "input_json_delta":
content_block = event.delta.model_dump()
content_block["index"] = event.index
+start_event_block = (
+getattr(block_start_event, "content_block", None)
+if block_start_event
+else None
+)
if (
-(block_start_event is not None)
-and hasattr(block_start_event, "content_block")
-and (block_start_event.content_block.type == "tool_use")
+start_event_block is not None
+and getattr(start_event_block, "type", None) == "tool_use"
):
tool_call_chunk = create_tool_call_chunk(
index=event.index,
@@ -2293,12 +2323,10 @@ def _make_message_chunk_from_anthropic_event(
content=[content_block],
tool_call_chunks=tool_call_chunks,
)
+# Process final usage metadata and completion info
elif event.type == "message_delta" and stream_usage:
-usage_metadata = UsageMetadata(
-input_tokens=0,
-output_tokens=event.usage.output_tokens,
-total_tokens=event.usage.output_tokens,
-)
+usage_metadata = _create_usage_metadata(event.usage)
message_chunk = AIMessageChunk(
content="",
usage_metadata=usage_metadata,
@@ -2307,6 +2335,8 @@ def _make_message_chunk_from_anthropic_event(
"stop_sequence": event.delta.stop_sequence,
},
)
+# Unhandled event types (e.g., `content_block_stop`, `ping` events)
+# https://docs.anthropic.com/en/docs/build-with-claude/streaming#other-events
else:
pass
@@ -2319,26 +2349,38 @@ class ChatAnthropicMessages(ChatAnthropic):
def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
"""Create LangChain `UsageMetadata` from Anthropic `Usage` data.
+Note: Anthropic's `input_tokens` excludes cached tokens, so we manually add
+`cache_read` and `cache_creation` tokens to get the true total.
"""
input_token_details: dict = {
"cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}
-# Add (beta) cache TTL information if available
+# Add cache TTL information if provided (5-minute and 1-hour ephemeral cache)
cache_creation = getattr(anthropic_usage, "cache_creation", None)
cache_creation_keys = ("ephemeral_1h_input_tokens", "ephemeral_5m_input_tokens")
# Currently just copying over the 5m and 1h keys, but if more are added in the
# future we'll need to expand this tuple
cache_creation_keys = ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens")
if cache_creation:
if isinstance(cache_creation, BaseModel):
cache_creation = cache_creation.model_dump()
for k in cache_creation_keys:
input_token_details[k] = cache_creation.get(k)
-# Anthropic input_tokens exclude cached token counts.
+# Calculate total input tokens: Anthropic's `input_tokens` excludes cached tokens,
+# so we need to add them back to get the true total input token count
input_tokens = (
-(getattr(anthropic_usage, "input_tokens", 0) or 0)
-+ (input_token_details["cache_read"] or 0)
-+ (input_token_details["cache_creation"] or 0)
+(getattr(anthropic_usage, "input_tokens", 0) or 0)  # Base input tokens
++ (input_token_details["cache_read"] or 0)  # Tokens read from cache
++ (input_token_details["cache_creation"] or 0)  # Tokens used to create cache
)
output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
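For reference, the total this computes can be checked by hand. A small
sketch of the same arithmetic, using the illustrative values from the unit
test added below (not part of this diff):

```python
base_input = 100      # Anthropic `input_tokens` (excludes cached tokens)
cache_read = 25       # `cache_read_input_tokens`
cache_creation = 10   # `cache_creation_input_tokens`
output = 50           # `output_tokens`

input_tokens = base_input + cache_read + cache_creation
total_tokens = input_tokens + output

assert input_tokens == 135
assert total_tokens == 185
```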

View File

@@ -7,7 +7,7 @@ authors = []
license = { text = "MIT" }
requires-python = ">=3.9"
dependencies = [
"anthropic<1,>=0.60.0",
"anthropic<1,>=0.64.0",
"langchain-core<1.0.0,>=0.3.72",
"pydantic<3.0.0,>=2.7.4",
]

View File

@@ -9,6 +9,8 @@ from langchain_core.outputs import LLMResult
from langchain_anthropic import Anthropic
from tests.unit_tests._utils import FakeCallbackHandler
MODEL = "claude-3-7-sonnet-latest"
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
@@ -24,14 +26,14 @@ def test_anthropic_model_param() -> None:
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
llm = Anthropic(model=MODEL) # type: ignore[call-arg]
output = llm.invoke("Say foo:")
assert isinstance(output, str)
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
llm = Anthropic(model=MODEL) # type: ignore[call-arg]
generator = llm.stream("I'm Pickle Rick")
assert isinstance(generator, Generator)
@@ -45,7 +47,7 @@ def test_anthropic_streaming_callback() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = Anthropic(
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
model=MODEL, # type: ignore[call-arg]
streaming=True,
callback_manager=callback_manager,
verbose=True,
@@ -56,7 +58,7 @@
async def test_anthropic_async_generate() -> None:
"""Test async generate."""
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
llm = Anthropic(model=MODEL) # type: ignore[call-arg]
output = await llm.agenerate(["How many toes do dogs have?"])
assert isinstance(output, LLMResult)
@@ -66,7 +68,7 @@ async def test_anthropic_async_streaming_callback() -> None:
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = Anthropic(
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
model=MODEL, # type: ignore[call-arg]
streaming=True,
callback_manager=callback_manager,
verbose=True,

View File

@@ -11,6 +11,8 @@ from langchain_anthropic import ChatAnthropic
REPO_ROOT_DIR = Path(__file__).parents[5]
MODEL = "claude-3-5-haiku-latest"
class TestAnthropicStandard(ChatModelIntegrationTests):
@property
@@ -19,7 +21,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
@property
def chat_model_params(self) -> dict:
return {"model": "claude-3-5-sonnet-latest"}
return {"model": MODEL}
@property
def supports_image_inputs(self) -> bool:
@@ -67,8 +69,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
llm = ChatAnthropic(
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
model=MODEL, # type: ignore[call-arg]
)
with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()
@@ -96,8 +97,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
llm = ChatAnthropic(
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
model=MODEL, # type: ignore[call-arg]
)
with open(REPO_ROOT_DIR / "README.md") as f:
readme = f.read()

View File

@@ -1213,3 +1213,76 @@ def test_cache_control_kwarg() -> None:
],
},
]
+def test_streaming_cache_token_reporting() -> None:
+"""Test that cache tokens are properly reported in streaming events."""
+from unittest.mock import MagicMock
+from anthropic.types import MessageDeltaUsage
+from langchain_anthropic.chat_models import _make_message_chunk_from_anthropic_event
+# Create a mock message_start event
+mock_message = MagicMock()
+mock_message.model = "claude-3-sonnet-20240229"
+mock_message.usage.input_tokens = 100
+mock_message.usage.output_tokens = 0
+mock_message.usage.cache_read_input_tokens = 25
+mock_message.usage.cache_creation_input_tokens = 10
+message_start_event = MagicMock()
+message_start_event.type = "message_start"
+message_start_event.message = mock_message
+# Create a mock message_delta event with complete usage info
+mock_delta_usage = MessageDeltaUsage(
+output_tokens=50,
+input_tokens=100,
+cache_read_input_tokens=25,
+cache_creation_input_tokens=10,
+)
+mock_delta = MagicMock()
+mock_delta.stop_reason = "end_turn"
+mock_delta.stop_sequence = None
+message_delta_event = MagicMock()
+message_delta_event.type = "message_delta"
+message_delta_event.usage = mock_delta_usage
+message_delta_event.delta = mock_delta
+# Test message_start event
+start_chunk, _ = _make_message_chunk_from_anthropic_event(
+message_start_event,
+stream_usage=True,
+coerce_content_to_string=True,
+block_start_event=None,
+)
+# Test message_delta event - should contain complete usage metadata (w/ cache)
+delta_chunk, _ = _make_message_chunk_from_anthropic_event(
+message_delta_event,
+stream_usage=True,
+coerce_content_to_string=True,
+block_start_event=None,
+)
+# Verify message_delta has complete usage_metadata including cache tokens
+assert start_chunk is not None, "message_start should produce a chunk"
+assert getattr(start_chunk, "usage_metadata", None) is None, (
+"message_start should not have usage_metadata"
+)
+assert delta_chunk is not None, "message_delta should produce a chunk"
+assert delta_chunk.usage_metadata is not None, (
+"message_delta should have usage_metadata"
+)
+assert "input_token_details" in delta_chunk.usage_metadata
+input_details = delta_chunk.usage_metadata["input_token_details"]
+assert input_details.get("cache_read") == 25
+assert input_details.get("cache_creation") == 10
+# Verify totals are correct: 100 base + 25 cache_read + 10 cache_creation = 135
+assert delta_chunk.usage_metadata["input_tokens"] == 135
+assert delta_chunk.usage_metadata["output_tokens"] == 50
+assert delta_chunk.usage_metadata["total_tokens"] == 185

View File

@@ -1,5 +1,5 @@
version = 1
-revision = 2
+revision = 3
requires-python = ">=3.9"
resolution-markers = [
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
@@ -469,7 +469,7 @@ typing = [
[package.metadata]
requires-dist = [
{ name = "anthropic", specifier = ">=0.60.0,<1" },
{ name = "anthropic", specifier = ">=0.64.0,<1" },
{ name = "langchain-core", editable = "../../core" },
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
]