This commit is contained in:
Chester Curme
2025-05-08 21:05:49 -04:00
parent e9e597be8e
commit f5f6e869cd
2 changed files with 3 additions and 16 deletions

View File

@@ -1761,12 +1761,6 @@ def _make_message_chunk_from_anthropic_event(
     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
         usage_metadata = _create_usage_metadata(event.message.usage)
-        # We pick up a cumulative count of output_tokens at the end of the stream,
-        # so here we zero out to avoid double counting.
-        usage_metadata["total_tokens"] = (
-            usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
-        )
-        usage_metadata["output_tokens"] = 0
         if hasattr(event.message, "model"):
             response_metadata = {"model_name": event.message.model}
         else:
@@ -1840,11 +1834,7 @@ def _make_message_chunk_from_anthropic_event(
             tool_call_chunks=[tool_call_chunk],  # type: ignore
         )
     elif event.type == "message_delta" and stream_usage:
-        usage_metadata = UsageMetadata(
-            input_tokens=0,
-            output_tokens=event.usage.output_tokens,
-            total_tokens=event.usage.output_tokens,
-        )
+        usage_metadata = _create_usage_metadata(event.usage)
         message_chunk = AIMessageChunk(
             content="",
             usage_metadata=usage_metadata,

View File

@@ -46,7 +46,7 @@ def test_stream() -> None:
         if token.usage_metadata is not None:
             if token.usage_metadata.get("input_tokens"):
                 chunks_with_input_token_counts += 1
-            if token.usage_metadata.get("output_tokens"):
+            elif token.usage_metadata.get("output_tokens"):
                 chunks_with_output_token_counts += 1
         chunks_with_model_name += int("model_name" in token.response_metadata)
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
@@ -85,7 +85,7 @@ async def test_astream() -> None:
         if token.usage_metadata is not None:
             if token.usage_metadata.get("input_tokens"):
                 chunks_with_input_token_counts += 1
-            if token.usage_metadata.get("output_tokens"):
+            elif token.usage_metadata.get("output_tokens"):
                 chunks_with_output_token_counts += 1
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
         raise AssertionError(
@@ -134,9 +134,6 @@ async def test_stream_usage() -> None:
     async for token in model.astream("hi"):
         assert isinstance(token, AIMessageChunk)
         assert token.usage_metadata is None
[NOTE: the hunk header says 3 lines were deleted here, but the extracted page does not show them]
 async def test_stream_usage_override() -> None:
     # check we override with kwarg
     model = ChatAnthropic(model_name=MODEL_NAME)  # type: ignore[call-arg]
     assert model.stream_usage