diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 4e7faf2c0ed..ce9c25cb6cc 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -456,6 +456,12 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
         None."""
+    stream_usage: bool = False
+    """Whether to include usage metadata in streaming output. If True, an additional
+    message chunk will be generated during the stream including usage metadata.
+
+    .. versionadded:: 0.3.9
+    """
     max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
@@ -811,14 +817,38 @@ class BaseChatOpenAI(BaseChatModel):
                 is_first_chunk = False
                 yield generation_chunk
 
+    def _should_stream_usage(
+        self, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> bool:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of precedence
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
     def _stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
@@ -1005,9 +1035,14 @@ class BaseChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
@@ -2202,11 +2237,6 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
 
     """  # noqa: E501
 
-    stream_usage: bool = False
-    """Whether to include usage metadata in streaming output. If True, additional
-    message chunks will be generated during the stream including usage metadata.
-    """
-
     max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
     """Maximum number of tokens to generate."""
 
@@ -2268,55 +2298,21 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
             message["role"] = "developer"
         return payload
 
-    def _should_stream_usage(
-        self, stream_usage: Optional[bool] = None, **kwargs: Any
-    ) -> bool:
-        """Determine whether to include usage metadata in streaming output.
-
-        For backwards compatibility, we check for `stream_options` passed
-        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
-        """
-        stream_usage_sources = [  # order of preference
-            stream_usage,
-            kwargs.get("stream_options", {}).get("include_usage"),
-            self.model_kwargs.get("stream_options", {}).get("include_usage"),
-            self.stream_usage,
-        ]
-        for source in stream_usage_sources:
-            if isinstance(source, bool):
-                return source
-        return self.stream_usage
-
-    def _stream(
-        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
-    ) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
             return super()._stream_responses(*args, **kwargs)
         else:
-            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-            # Note: stream_options is not a valid parameter for Azure OpenAI.
-            # To support users proxying Azure through ChatOpenAI, here we only specify
-            # stream_options if include_usage is set to True.
-            # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-            # for release notes.
-            if stream_usage:
-                kwargs["stream_options"] = {"include_usage": stream_usage}
-
             return super()._stream(*args, **kwargs)
 
     async def _astream(
-        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+        self, *args: Any, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
+        """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
-            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-            if stream_usage:
-                kwargs["stream_options"] = {"include_usage": stream_usage}
-
             async for chunk in super()._astream(*args, **kwargs):
                 yield chunk
 
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
index f5820794bb3..b87be33b30e 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
@@ -3,7 +3,6 @@
 import os
 from typing import Type
 
-import pytest
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -25,6 +24,7 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
             "model": "gpt-4o-mini",
             "openai_api_version": OPENAI_API_VERSION,
             "azure_endpoint": OPENAI_API_BASE,
+            "stream_usage": True,
         }
 
     @property
@@ -35,10 +35,6 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
     def supports_json_mode(self) -> bool:
         return True
 
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
-
 
 class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
     """Test a legacy model."""
@@ -53,12 +49,9 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
             "deployment_name": os.environ["AZURE_OPENAI_LEGACY_CHAT_DEPLOYMENT_NAME"],
             "openai_api_version": OPENAI_API_VERSION,
             "azure_endpoint": OPENAI_API_BASE,
+            "stream_usage": True,
         }
 
     @property
     def structured_output_kwargs(self) -> dict:
         return {"method": "function_calling"}
-
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
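
For reference, a minimal usage sketch of the promoted `stream_usage` flag (the model name and prompt are illustrative; assumes a release containing this change, e.g. langchain-openai 0.3.9 or later):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

usage = None
for chunk in llm.stream("Write a haiku about token counting"):
    # With stream_usage=True, a final chunk carrying usage_metadata
    # (input/output/total token counts) is emitted at the end of the stream.
    if chunk.usage_metadata is not None:
        usage = chunk.usage_metadata
print(usage)

Per `_should_stream_usage`, a per-call `stream_usage=` argument takes precedence, then an explicit `stream_options={"include_usage": ...}` in kwargs or `model_kwargs`, then the constructor value.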