diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
index 0fe97e80366..cf6866638ca 100644
--- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py
@@ -8,14 +8,16 @@ from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
 
 MODEL_NAME = "claude-3-5-haiku-latest"
 
-async def test_astream() -> None:
+
+def test_stream() -> None:
     """Test streaming tokens from Anthropic."""
     llm = ChatAnthropicMessages(model_name=MODEL_NAME)  # type: ignore[call-arg, call-arg]
 
     full: Optional[BaseMessageChunk] = None
     chunks_with_input_token_counts = 0
     chunks_with_output_token_counts = 0
-    async for token in llm.astream("I'm Pickle Rick"):
+    chunks_with_model_name = 0
+    for token in llm.stream("I'm Pickle Rick"):
         assert isinstance(token.content, str)
         full = token if full is None else full + token
         assert isinstance(token, AIMessageChunk)
@@ -24,12 +26,14 @@ async def test_astream() -> None:
             chunks_with_input_token_counts += 1
         if token.usage_metadata.get("output_tokens"):
             chunks_with_output_token_counts += 1
+        chunks_with_model_name += int("model_name" in token.response_metadata)
     if chunks_with_input_token_counts != 1 or chunks_with_output_token_counts != 1:
         raise AssertionError(
             "Expected exactly one chunk with input or output token counts. "
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
+    assert chunks_with_model_name == 1
     # check token usage is populated
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
@@ -42,27 +46,7 @@ async def test_astream() -> None:
     )
     assert "stop_reason" in full.response_metadata
     assert "stop_sequence" in full.response_metadata
-
-    # Check expected raw API output
-    async_client = llm._async_client
-    params: dict = {
-        "model": MODEL_NAME,
-        "max_tokens": 1024,
-        "messages": [{"role": "user", "content": "hi"}],
-        "temperature": 0.0,
-    }
-    stream = await async_client.messages.create(**params, stream=True)
-    async for event in stream:
-        if event.type == "message_start":
-            assert event.message.usage.input_tokens > 1
-            # Note: this single output token included in message start event
-            # does not appear to contribute to overall output token counts. It
-            # is excluded from the total token count.
-            assert event.message.usage.output_tokens == 1
-        elif event.type == "message_delta":
-            assert event.usage.output_tokens > 1
-        else:
-            pass
+    assert "model_name" in full.response_metadata
 
 
 async def test_stream_usage() -> None: