mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-22 10:59:22 +00:00
fix(anthropic): correct input_token
count for streaming (#32591)
* Create usage metadata on [`message_delta`](https://docs.anthropic.com/en/docs/build-with-claude/streaming#event-types) instead of at the beginning. Consequently, token counts are not included during streaming but instead at the end. This allows for accurate reporting of server-side tool usage (important for billing) * Add some clarifying comments * Fix some outstanding Pylance warnings * Remove unnecessary `text` popping in thinking blocks * Also now correctly reports `input_cache_read`/`input_cache_creation` as a result
This commit is contained in:
parent
8042b04da6
commit
8d0fb2d04b
@ -2192,47 +2192,65 @@ def _make_message_chunk_from_anthropic_event(
|
||||
coerce_content_to_string: bool,
|
||||
block_start_event: Optional[anthropic.types.RawMessageStreamEvent] = None,
|
||||
) -> tuple[Optional[AIMessageChunk], Optional[anthropic.types.RawMessageStreamEvent]]:
|
||||
"""Convert Anthropic event to AIMessageChunk.
|
||||
"""Convert Anthropic streaming event to `AIMessageChunk`.
|
||||
|
||||
Args:
|
||||
event: Raw streaming event from Anthropic SDK
|
||||
stream_usage: Whether to include usage metadata in the output chunks.
|
||||
coerce_content_to_string: Whether to convert structured content to plain
|
||||
text strings. When True, only text content is preserved; when False,
|
||||
structured content like tool calls and citations are maintained.
|
||||
block_start_event: Previous content block start event, used for tracking
|
||||
tool use blocks and maintaining context across related events.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- AIMessageChunk: Converted message chunk with appropriate content and
|
||||
metadata, or None if the event doesn't produce a chunk
|
||||
- RawMessageStreamEvent: Updated `block_start_event` for tracking content
|
||||
blocks across sequential events, or None if not applicable
|
||||
|
||||
Note:
|
||||
Not all Anthropic events result in message chunks. Events like internal
|
||||
state changes return None for the message chunk while potentially
|
||||
updating the `block_start_event` for context tracking.
|
||||
|
||||
Note that not all events will result in a message chunk. In these cases
|
||||
we return ``None``.
|
||||
"""
|
||||
message_chunk: Optional[AIMessageChunk] = None
|
||||
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
|
||||
# Reference: Anthropic SDK streaming implementation
|
||||
# https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
|
||||
|
||||
if event.type == "message_start" and stream_usage:
|
||||
usage_metadata = _create_usage_metadata(event.message.usage)
|
||||
# We pick up a cumulative count of output_tokens at the end of the stream,
|
||||
# so here we zero out to avoid double counting.
|
||||
usage_metadata["total_tokens"] = (
|
||||
usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
|
||||
)
|
||||
usage_metadata["output_tokens"] = 0
|
||||
# Capture model name, but don't include usage_metadata yet
|
||||
# as it will be properly reported in message_delta with complete info
|
||||
if hasattr(event.message, "model"):
|
||||
response_metadata = {"model_name": event.message.model}
|
||||
else:
|
||||
response_metadata = {}
|
||||
|
||||
message_chunk = AIMessageChunk(
|
||||
content="" if coerce_content_to_string else [],
|
||||
usage_metadata=usage_metadata,
|
||||
response_metadata=response_metadata,
|
||||
)
|
||||
|
||||
elif (
|
||||
event.type == "content_block_start"
|
||||
and event.content_block is not None
|
||||
and event.content_block.type
|
||||
in (
|
||||
"tool_use",
|
||||
"code_execution_tool_result",
|
||||
"tool_use", # Standard tool usage
|
||||
"code_execution_tool_result", # Built-in code execution results
|
||||
"document",
|
||||
"redacted_thinking",
|
||||
"mcp_tool_use",
|
||||
"mcp_tool_result",
|
||||
"server_tool_use",
|
||||
"web_search_tool_result",
|
||||
"server_tool_use", # Server-side tool usage
|
||||
"web_search_tool_result", # Built-in web search results
|
||||
)
|
||||
):
|
||||
if coerce_content_to_string:
|
||||
warnings.warn("Received unexpected tool content block.", stacklevel=2)
|
||||
|
||||
content_block = event.content_block.model_dump()
|
||||
content_block["index"] = event.index
|
||||
if event.content_block.type == "tool_use":
|
||||
@ -2250,35 +2268,47 @@ def _make_message_chunk_from_anthropic_event(
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
block_start_event = event
|
||||
|
||||
# Process incremental content updates
|
||||
elif event.type == "content_block_delta":
|
||||
# Text and citation deltas (incremental text content)
|
||||
if event.delta.type in ("text_delta", "citations_delta"):
|
||||
if coerce_content_to_string and hasattr(event.delta, "text"):
|
||||
text = event.delta.text
|
||||
text = getattr(event.delta, "text", "")
|
||||
message_chunk = AIMessageChunk(content=text)
|
||||
else:
|
||||
content_block = event.delta.model_dump()
|
||||
content_block["index"] = event.index
|
||||
|
||||
# All citation deltas are part of a text block
|
||||
content_block["type"] = "text"
|
||||
if "citation" in content_block:
|
||||
# Assign citations to a list if present
|
||||
content_block["citations"] = [content_block.pop("citation")]
|
||||
message_chunk = AIMessageChunk(content=[content_block])
|
||||
|
||||
# Reasoning
|
||||
elif (
|
||||
event.delta.type == "thinking_delta"
|
||||
or event.delta.type == "signature_delta"
|
||||
):
|
||||
content_block = event.delta.model_dump()
|
||||
if "text" in content_block and content_block["text"] is None:
|
||||
content_block.pop("text")
|
||||
content_block["index"] = event.index
|
||||
content_block["type"] = "thinking"
|
||||
message_chunk = AIMessageChunk(content=[content_block])
|
||||
|
||||
# Tool input JSON (streaming tool arguments)
|
||||
elif event.delta.type == "input_json_delta":
|
||||
content_block = event.delta.model_dump()
|
||||
content_block["index"] = event.index
|
||||
start_event_block = (
|
||||
getattr(block_start_event, "content_block", None)
|
||||
if block_start_event
|
||||
else None
|
||||
)
|
||||
if (
|
||||
(block_start_event is not None)
|
||||
and hasattr(block_start_event, "content_block")
|
||||
and (block_start_event.content_block.type == "tool_use")
|
||||
start_event_block is not None
|
||||
and getattr(start_event_block, "type", None) == "tool_use"
|
||||
):
|
||||
tool_call_chunk = create_tool_call_chunk(
|
||||
index=event.index,
|
||||
@ -2293,12 +2323,10 @@ def _make_message_chunk_from_anthropic_event(
|
||||
content=[content_block],
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
|
||||
# Process final usage metadata and completion info
|
||||
elif event.type == "message_delta" and stream_usage:
|
||||
usage_metadata = UsageMetadata(
|
||||
input_tokens=0,
|
||||
output_tokens=event.usage.output_tokens,
|
||||
total_tokens=event.usage.output_tokens,
|
||||
)
|
||||
usage_metadata = _create_usage_metadata(event.usage)
|
||||
message_chunk = AIMessageChunk(
|
||||
content="",
|
||||
usage_metadata=usage_metadata,
|
||||
@ -2307,6 +2335,8 @@ def _make_message_chunk_from_anthropic_event(
|
||||
"stop_sequence": event.delta.stop_sequence,
|
||||
},
|
||||
)
|
||||
# Unhandled event types (e.g., `content_block_stop`, `ping` events)
|
||||
# https://docs.anthropic.com/en/docs/build-with-claude/streaming#other-events
|
||||
else:
|
||||
pass
|
||||
|
||||
@ -2319,26 +2349,38 @@ class ChatAnthropicMessages(ChatAnthropic):
|
||||
|
||||
|
||||
def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
|
||||
"""Create LangChain `UsageMetadata` from Anthropic `Usage` data.
|
||||
|
||||
Note: Anthropic's `input_tokens` excludes cached tokens, so we manually add
|
||||
`cache_read` and `cache_creation` tokens to get the true total.
|
||||
|
||||
"""
|
||||
input_token_details: dict = {
|
||||
"cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
|
||||
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
|
||||
}
|
||||
# Add (beta) cache TTL information if available
|
||||
|
||||
# Add cache TTL information if provided (5-minute and 1-hour ephemeral cache)
|
||||
cache_creation = getattr(anthropic_usage, "cache_creation", None)
|
||||
cache_creation_keys = ("ephemeral_1h_input_tokens", "ephemeral_5m_input_tokens")
|
||||
|
||||
# Currently just copying over the 5m and 1h keys, but if more are added in the
|
||||
# future we'll need to expand this tuple
|
||||
cache_creation_keys = ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens")
|
||||
if cache_creation:
|
||||
if isinstance(cache_creation, BaseModel):
|
||||
cache_creation = cache_creation.model_dump()
|
||||
for k in cache_creation_keys:
|
||||
input_token_details[k] = cache_creation.get(k)
|
||||
|
||||
# Anthropic input_tokens exclude cached token counts.
|
||||
# Calculate total input tokens: Anthropic's `input_tokens` excludes cached tokens,
|
||||
# so we need to add them back to get the true total input token count
|
||||
input_tokens = (
|
||||
(getattr(anthropic_usage, "input_tokens", 0) or 0)
|
||||
+ (input_token_details["cache_read"] or 0)
|
||||
+ (input_token_details["cache_creation"] or 0)
|
||||
(getattr(anthropic_usage, "input_tokens", 0) or 0) # Base input tokens
|
||||
+ (input_token_details["cache_read"] or 0) # Tokens read from cache
|
||||
+ (input_token_details["cache_creation"] or 0) # Tokens used to create cache
|
||||
)
|
||||
output_tokens = getattr(anthropic_usage, "output_tokens", 0) or 0
|
||||
|
||||
return UsageMetadata(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
|
@ -7,7 +7,7 @@ authors = []
|
||||
license = { text = "MIT" }
|
||||
requires-python = ">=3.9"
|
||||
dependencies = [
|
||||
"anthropic<1,>=0.60.0",
|
||||
"anthropic<1,>=0.64.0",
|
||||
"langchain-core<1.0.0,>=0.3.72",
|
||||
"pydantic<3.0.0,>=2.7.4",
|
||||
]
|
||||
|
Binary file not shown.
@ -9,6 +9,8 @@ from langchain_core.outputs import LLMResult
|
||||
from langchain_anthropic import Anthropic
|
||||
from tests.unit_tests._utils import FakeCallbackHandler
|
||||
|
||||
MODEL = "claude-3-7-sonnet-latest"
|
||||
|
||||
|
||||
@pytest.mark.requires("anthropic")
|
||||
def test_anthropic_model_name_param() -> None:
|
||||
@ -24,14 +26,14 @@ def test_anthropic_model_param() -> None:
|
||||
|
||||
def test_anthropic_call() -> None:
|
||||
"""Test valid call to anthropic."""
|
||||
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
|
||||
llm = Anthropic(model=MODEL) # type: ignore[call-arg]
|
||||
output = llm.invoke("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_anthropic_streaming() -> None:
|
||||
"""Test streaming tokens from anthropic."""
|
||||
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
|
||||
llm = Anthropic(model=MODEL) # type: ignore[call-arg]
|
||||
generator = llm.stream("I'm Pickle Rick")
|
||||
|
||||
assert isinstance(generator, Generator)
|
||||
@ -45,7 +47,7 @@ def test_anthropic_streaming_callback() -> None:
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
llm = Anthropic(
|
||||
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
|
||||
model=MODEL, # type: ignore[call-arg]
|
||||
streaming=True,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
@ -56,7 +58,7 @@ def test_anthropic_streaming_callback() -> None:
|
||||
|
||||
async def test_anthropic_async_generate() -> None:
|
||||
"""Test async generate."""
|
||||
llm = Anthropic(model="claude-3-7-sonnet-20250219") # type: ignore[call-arg]
|
||||
llm = Anthropic(model=MODEL) # type: ignore[call-arg]
|
||||
output = await llm.agenerate(["How many toes do dogs have?"])
|
||||
assert isinstance(output, LLMResult)
|
||||
|
||||
@ -66,7 +68,7 @@ async def test_anthropic_async_streaming_callback() -> None:
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
llm = Anthropic(
|
||||
model="claude-3-7-sonnet-20250219", # type: ignore[call-arg]
|
||||
model=MODEL, # type: ignore[call-arg]
|
||||
streaming=True,
|
||||
callback_manager=callback_manager,
|
||||
verbose=True,
|
||||
|
@ -11,6 +11,8 @@ from langchain_anthropic import ChatAnthropic
|
||||
|
||||
REPO_ROOT_DIR = Path(__file__).parents[5]
|
||||
|
||||
MODEL = "claude-3-5-haiku-latest"
|
||||
|
||||
|
||||
class TestAnthropicStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
@ -19,7 +21,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "claude-3-5-sonnet-latest"}
|
||||
return {"model": MODEL}
|
||||
|
||||
@property
|
||||
def supports_image_inputs(self) -> bool:
|
||||
@ -67,8 +69,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
|
||||
|
||||
def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
|
||||
llm = ChatAnthropic(
|
||||
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
|
||||
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
|
||||
model=MODEL, # type: ignore[call-arg]
|
||||
)
|
||||
with open(REPO_ROOT_DIR / "README.md") as f:
|
||||
readme = f.read()
|
||||
@ -96,8 +97,7 @@ class TestAnthropicStandard(ChatModelIntegrationTests):
|
||||
|
||||
def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
|
||||
llm = ChatAnthropic(
|
||||
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
|
||||
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
|
||||
model=MODEL, # type: ignore[call-arg]
|
||||
)
|
||||
with open(REPO_ROOT_DIR / "README.md") as f:
|
||||
readme = f.read()
|
||||
|
@ -1213,3 +1213,76 @@ def test_cache_control_kwarg() -> None:
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_streaming_cache_token_reporting() -> None:
|
||||
"""Test that cache tokens are properly reported in streaming events."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from anthropic.types import MessageDeltaUsage
|
||||
|
||||
from langchain_anthropic.chat_models import _make_message_chunk_from_anthropic_event
|
||||
|
||||
# Create a mock message_start event
|
||||
mock_message = MagicMock()
|
||||
mock_message.model = "claude-3-sonnet-20240229"
|
||||
mock_message.usage.input_tokens = 100
|
||||
mock_message.usage.output_tokens = 0
|
||||
mock_message.usage.cache_read_input_tokens = 25
|
||||
mock_message.usage.cache_creation_input_tokens = 10
|
||||
|
||||
message_start_event = MagicMock()
|
||||
message_start_event.type = "message_start"
|
||||
message_start_event.message = mock_message
|
||||
|
||||
# Create a mock message_delta event with complete usage info
|
||||
mock_delta_usage = MessageDeltaUsage(
|
||||
output_tokens=50,
|
||||
input_tokens=100,
|
||||
cache_read_input_tokens=25,
|
||||
cache_creation_input_tokens=10,
|
||||
)
|
||||
|
||||
mock_delta = MagicMock()
|
||||
mock_delta.stop_reason = "end_turn"
|
||||
mock_delta.stop_sequence = None
|
||||
|
||||
message_delta_event = MagicMock()
|
||||
message_delta_event.type = "message_delta"
|
||||
message_delta_event.usage = mock_delta_usage
|
||||
message_delta_event.delta = mock_delta
|
||||
|
||||
# Test message_start event
|
||||
start_chunk, _ = _make_message_chunk_from_anthropic_event(
|
||||
message_start_event,
|
||||
stream_usage=True,
|
||||
coerce_content_to_string=True,
|
||||
block_start_event=None,
|
||||
)
|
||||
|
||||
# Test message_delta event - should contain complete usage metadata (w/ cache)
|
||||
delta_chunk, _ = _make_message_chunk_from_anthropic_event(
|
||||
message_delta_event,
|
||||
stream_usage=True,
|
||||
coerce_content_to_string=True,
|
||||
block_start_event=None,
|
||||
)
|
||||
|
||||
# Verify message_delta has complete usage_metadata including cache tokens
|
||||
assert start_chunk is not None, "message_start should produce a chunk"
|
||||
assert getattr(start_chunk, "usage_metadata", None) is None, (
|
||||
"message_start should not have usage_metadata"
|
||||
)
|
||||
assert delta_chunk is not None, "message_delta should produce a chunk"
|
||||
assert delta_chunk.usage_metadata is not None, (
|
||||
"message_delta should have usage_metadata"
|
||||
)
|
||||
assert "input_token_details" in delta_chunk.usage_metadata
|
||||
input_details = delta_chunk.usage_metadata["input_token_details"]
|
||||
assert input_details.get("cache_read") == 25
|
||||
assert input_details.get("cache_creation") == 10
|
||||
|
||||
# Verify totals are correct: 100 base + 25 cache_read + 10 cache_creation = 135
|
||||
assert delta_chunk.usage_metadata["input_tokens"] == 135
|
||||
assert delta_chunk.usage_metadata["output_tokens"] == 50
|
||||
assert delta_chunk.usage_metadata["total_tokens"] == 185
|
||||
|
@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.9"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.13' and platform_python_implementation == 'PyPy'",
|
||||
@ -469,7 +469,7 @@ typing = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "anthropic", specifier = ">=0.60.0,<1" },
|
||||
{ name = "anthropic", specifier = ">=0.64.0,<1" },
|
||||
{ name = "langchain-core", editable = "../../core" },
|
||||
{ name = "pydantic", specifier = ">=2.7.4,<3.0.0" },
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user