diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 6d4007e8879..9571781731c 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -15,6 +15,7 @@ from io import BytesIO from math import ceil from operator import itemgetter from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Callable, @@ -100,11 +101,13 @@ from langchain_core.utils.pydantic import ( is_basemodel_subclass, ) from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env -from openai.types.responses import Response from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import Self +if TYPE_CHECKING: + from openai.types.responses import Response + logger = logging.getLogger(__name__) # This SSL context is equivelent to the default `verify=True`. @@ -921,11 +924,15 @@ class BaseChatOpenAI(BaseChatModel): if error := response.error: raise ValueError(error) - token_usage = response.usage.model_dump() + token_usage = response.usage.model_dump() if response.usage else {} generation_info = {} for output in response.output: if output.type == "message": - joined = "".join(content.text for content in output.content) + joined = "".join( + content.text + for content in output.content + if content.type == "output_text" + ) usage_metadata = _create_usage_metadata_responses(token_usage) message = AIMessage( content=joined, id=output.id, usage_metadata=usage_metadata diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index bfae40ebfd9..e4643c1acee 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1238,6 +1238,7 @@ def test_web_search() -> None: ) assert isinstance(response, AIMessage) assert response.content + assert response.usage_metadata assert response.usage_metadata["input_tokens"] > 0 assert response.usage_metadata["output_tokens"] > 0 assert response.usage_metadata["total_tokens"] > 0 @@ -1253,6 +1254,7 @@ async def test_web_search_async() -> None: ) assert isinstance(response, AIMessage) assert response.content + assert response.usage_metadata assert response.usage_metadata["input_tokens"] > 0 assert response.usage_metadata["output_tokens"] > 0 assert response.usage_metadata["total_tokens"] > 0