From 8119a7bc5c6379064eecf2890539fab90e68d8f7 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Wed, 26 Mar 2025 15:16:37 -0400
Subject: [PATCH] openai[patch]: support streaming token counts in AzureChatOpenAI (#30494)

When OpenAI originally released `stream_options` to enable token usage during
streaming, it was not supported in AzureOpenAI. It is now supported.

Like the [OpenAI SDK](https://github.com/openai/openai-python/blob/f66d2e6fdc51c4528c99bb25a8fbca6f9b9b872d/src/openai/resources/completions.py#L68),
ChatOpenAI does not return usage metadata during streaming by default (which
adds an extra chunk to the stream). The OpenAI SDK requires users to pass
`stream_options={"include_usage": True}`. ChatOpenAI implements a convenience
argument `stream_usage: Optional[bool]`, and an attribute
`stream_usage: bool = False`.

Here we extend this to AzureChatOpenAI by moving the `stream_usage` attribute
and `stream_usage` kwarg (on `_(a)stream`) from ChatOpenAI to BaseChatOpenAI.

---

Additional consideration: we must be sensitive to the number of users using
BaseChatOpenAI to interact with other APIs that do not support the
`stream_options` parameter.

Suppose OpenAI in the future updates the default behavior to stream token
usage. Currently, BaseChatOpenAI only passes `stream_options` if
`stream_usage` is True, so there would be no way to disable this new default
behavior.

To address this, we could update the `stream_usage` attribute to
`Optional[bool] = None`, but this is technically a breaking change (as
currently values of False are not passed to the client).

IMO: if / when this change happens, we could accompany it with this update in
a minor bump.

---

Related previous PRs:
- https://github.com/langchain-ai/langchain/pull/22628
- https://github.com/langchain-ai/langchain/pull/22854
- https://github.com/langchain-ai/langchain/pull/23552

---------

Co-authored-by: Eugene Yurtsev
---
 .../langchain_openai/chat_models/base.py   | 82 +++++++++----------
 .../chat_models/test_azure_standard.py      | 11 +--
 2 files changed, 41 insertions(+), 52 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 4e7faf2c0ed..ce9c25cb6cc 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -456,6 +456,12 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
         None."""
+    stream_usage: bool = False
+    """Whether to include usage metadata in streaming output. If True, an additional
+    message chunk will be generated during the stream including usage metadata.
+
+    .. versionadded:: 0.3.9
+    """
     max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
@@ -811,14 +817,38 @@ class BaseChatOpenAI(BaseChatModel):
                 is_first_chunk = False
             yield generation_chunk
 
+    def _should_stream_usage(
+        self, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> bool:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of precedence
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
     def _stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
@@ -1005,9 +1035,14 @@ class BaseChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
@@ -2202,11 +2237,6 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
 
     """  # noqa: E501
 
-    stream_usage: bool = False
-    """Whether to include usage metadata in streaming output. If True, additional
-    message chunks will be generated during the stream including usage metadata.
-    """
-
     max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
     """Maximum number of tokens to generate."""
 
@@ -2268,55 +2298,21 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
                     message["role"] = "developer"
         return payload
 
-    def _should_stream_usage(
-        self, stream_usage: Optional[bool] = None, **kwargs: Any
-    ) -> bool:
-        """Determine whether to include usage metadata in streaming output.
-
-        For backwards compatibility, we check for `stream_options` passed
-        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
-        """
-        stream_usage_sources = [  # order of preference
-            stream_usage,
-            kwargs.get("stream_options", {}).get("include_usage"),
-            self.model_kwargs.get("stream_options", {}).get("include_usage"),
-            self.stream_usage,
-        ]
-        for source in stream_usage_sources:
-            if isinstance(source, bool):
-                return source
-        return self.stream_usage
-
-    def _stream(
-        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
-    ) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
             return super()._stream_responses(*args, **kwargs)
         else:
-            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-            # Note: stream_options is not a valid parameter for Azure OpenAI.
-            # To support users proxying Azure through ChatOpenAI, here we only specify
-            # stream_options if include_usage is set to True.
-            # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-            # for release notes.
-            if stream_usage:
-                kwargs["stream_options"] = {"include_usage": stream_usage}
-
             return super()._stream(*args, **kwargs)
 
     async def _astream(
-        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+        self, *args: Any, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
+        """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
-            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-            if stream_usage:
-                kwargs["stream_options"] = {"include_usage": stream_usage}
-
             async for chunk in super()._astream(*args, **kwargs):
                 yield chunk
 
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
index f5820794bb3..b87be33b30e 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure_standard.py
@@ -3,7 +3,6 @@
 import os
 from typing import Type
 
-import pytest
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -25,6 +24,7 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
             "model": "gpt-4o-mini",
             "openai_api_version": OPENAI_API_VERSION,
             "azure_endpoint": OPENAI_API_BASE,
+            "stream_usage": True,
         }
 
     @property
@@ -35,10 +35,6 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
     def supports_json_mode(self) -> bool:
         return True
 
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
-
 
 class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
     """Test a legacy model."""
@@ -53,12 +49,9 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
             "deployment_name": os.environ["AZURE_OPENAI_LEGACY_CHAT_DEPLOYMENT_NAME"],
             "openai_api_version": OPENAI_API_VERSION,
             "azure_endpoint": OPENAI_API_BASE,
+            "stream_usage": True,
         }
 
     @property
     def structured_output_kwargs(self) -> dict:
         return {"method": "function_calling"}
-
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
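For reference, a minimal usage sketch of the behavior this patch enables (not part of the diff): it assumes an Azure OpenAI endpoint and API key are configured via the usual environment variables, and the deployment name and API version below are placeholders.

```python
from langchain_openai import AzureChatOpenAI

# Placeholders: substitute your own deployment name and API version.
llm = AzureChatOpenAI(
    azure_deployment="gpt-4o-mini",
    api_version="2024-10-21",
    stream_usage=True,  # opt in to a usage-metadata chunk in the stream
)

usage = None
for chunk in llm.stream("Write a haiku about token counting."):
    # With stream_usage=True, one chunk in the stream carries usage_metadata.
    if chunk.usage_metadata:
        usage = chunk.usage_metadata
print(usage)  # e.g. {"input_tokens": ..., "output_tokens": ..., "total_tokens": ...}

# The same opt-in can be made per call instead of on the model:
# for chunk in llm.stream("...", stream_usage=True): ...
```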