diff --git a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py index b5e0a9a0bd4..5c0c8c6628c 100644 --- a/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/integration_tests/chat_models.py @@ -1,6 +1,6 @@ import base64 import json -from typing import List, Optional +from typing import List, Optional, cast import httpx import pytest @@ -209,10 +209,21 @@ class ChatModelIntegrationTests(ChatModelTests): def test_usage_metadata_streaming(self, model: BaseChatModel) -> None: if not self.returns_usage_metadata: pytest.skip("Not implemented.") - full: Optional[BaseMessageChunk] = None - for chunk in model.stream("Hello"): + full: Optional[AIMessageChunk] = None + for chunk in model.stream("Write me 2 haikus. Only include the haikus."): assert isinstance(chunk, AIMessageChunk) - full = chunk if full is None else full + chunk + # only one chunk is allowed to set usage_metadata.input_tokens + # if multiple do, it's likely a bug that will result in overcounting + # input tokens + if full and full.usage_metadata and full.usage_metadata["input_tokens"]: + assert ( + not chunk.usage_metadata or not chunk.usage_metadata["input_tokens"] + ), ( + "Only one chunk should set input_tokens," + " the rest should be 0 or None" + ) + full = chunk if full is None else cast(AIMessageChunk, full + chunk) + assert isinstance(full, AIMessageChunk) assert full.usage_metadata is not None assert isinstance(full.usage_metadata["input_tokens"], int)