From f5f6e869cd60672ca87a3cdc694e6fe6d5d723b5 Mon Sep 17 00:00:00 2001
From: Chester Curme
Date: Thu, 8 May 2025 21:05:49 -0400
Subject: [PATCH] revert

---
 .../anthropic/langchain_anthropic/chat_models.py | 12 +-----------
 .../tests/integration_tests/test_chat_models.py  |  7 ++-----
 2 files changed, 3 insertions(+), 16 deletions(-)

diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index 661ef286c2b..25d52e21d3c 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -1761,12 +1761,6 @@ def _make_message_chunk_from_anthropic_event(
     # See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py  # noqa: E501
     if event.type == "message_start" and stream_usage:
         usage_metadata = _create_usage_metadata(event.message.usage)
-        # We pick up a cumulative count of output_tokens at the end of the stream,
-        # so here we zero out to avoid double counting.
-        usage_metadata["total_tokens"] = (
-            usage_metadata["total_tokens"] - usage_metadata["output_tokens"]
-        )
-        usage_metadata["output_tokens"] = 0
         if hasattr(event.message, "model"):
             response_metadata = {"model_name": event.message.model}
         else:
@@ -1840,11 +1834,7 @@ def _make_message_chunk_from_anthropic_event(
             tool_call_chunks=[tool_call_chunk],  # type: ignore
         )
     elif event.type == "message_delta" and stream_usage:
-        usage_metadata = UsageMetadata(
-            input_tokens=0,
-            output_tokens=event.usage.output_tokens,
-            total_tokens=event.usage.output_tokens,
-        )
+        usage_metadata = _create_usage_metadata(event.usage)
         message_chunk = AIMessageChunk(
             content="",
             usage_metadata=usage_metadata,
diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index a2563cc480d..5c41c3f878e 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -46,7 +46,7 @@ def test_stream() -> None:
         if token.usage_metadata is not None:
             if token.usage_metadata.get("input_tokens"):
                 chunks_with_input_token_counts += 1
-            if token.usage_metadata.get("output_tokens"):
+            elif token.usage_metadata.get("output_tokens"):
                 chunks_with_output_token_counts += 1
         chunks_with_model_name += int("model_name" in token.response_metadata)
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
@@ -85,7 +85,7 @@ async def test_astream() -> None:
         if token.usage_metadata is not None:
             if token.usage_metadata.get("input_tokens"):
                 chunks_with_input_token_counts += 1
-            if token.usage_metadata.get("output_tokens"):
+            elif token.usage_metadata.get("output_tokens"):
                 chunks_with_output_token_counts += 1
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
         raise AssertionError(
@@ -134,9 +134,6 @@ async def test_stream_usage() -> None:
     async for token in model.astream("hi"):
         assert isinstance(token, AIMessageChunk)
         assert token.usage_metadata is None
-
-
-async def test_stream_usage_override() -> None:
     # check we override with kwarg
     model = ChatAnthropic(model_name=MODEL_NAME)  # type: ignore[call-arg]
     assert model.stream_usage
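
For reviewers who want to exercise the reverted behavior locally, here is a
minimal sketch (not part of the patch) that streams a short prompt and merges
per-chunk usage metadata the same way downstream consumers typically do. It
assumes langchain-anthropic is installed with this change applied and that
ANTHROPIC_API_KEY is set; the model name and prompt are arbitrary placeholders.

    # Sketch: aggregate usage_metadata across streamed chunks.
    # Assumes ANTHROPIC_API_KEY is set; model name is a placeholder.
    from langchain_anthropic import ChatAnthropic

    llm = ChatAnthropic(model="claude-3-5-haiku-latest", stream_usage=True)

    aggregate = None
    for chunk in llm.stream("hi"):
        # AIMessageChunk.__add__ merges usage_metadata by summing token
        # counts, so the final aggregate covers the whole stream.
        aggregate = chunk if aggregate is None else aggregate + chunk

    print(aggregate.usage_metadata)

Note on the if/elif change in the tests: with the revert, the message_start
chunk carries the full usage snapshot from _create_usage_metadata, which may
include a nonzero output_tokens value, so a plain `if` could count that chunk
toward both the input-token and output-token tallies. The `elif` classifies
each chunk as one or the other, keeping both expected counts at exactly one.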