openai[patch]: support streaming token counts in AzureChatOpenAI ()

When OpenAI originally released `stream_options` to enable token usage
during streaming, it was not supported in AzureOpenAI. It is now
supported.

Like the [OpenAI
SDK](f66d2e6fdc/src/openai/resources/completions.py (L68)),
ChatOpenAI does not return usage metadata during streaming by default
(returning it adds an extra chunk to the stream). The OpenAI SDK requires users
to pass `stream_options={"include_usage": True}` to opt in; ChatOpenAI instead
implements a convenience keyword argument `stream_usage: Optional[bool]` and a
class attribute `stream_usage: bool = False`.
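
For context, a minimal sketch of how the opt-in looks from the caller's side (the model name and prompt are just illustrative):

```python
from langchain_openai import ChatOpenAI

# Opt in via the class attribute at construction time...
llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

# ...or per call via the convenience kwarg, which takes precedence.
aggregate = None
for chunk in llm.stream("hello", stream_usage=True):
    aggregate = chunk if aggregate is None else aggregate + chunk

# Usage metadata arrives on the extra final chunk and survives aggregation.
print(aggregate.usage_metadata)
```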

Here we extend this to AzureChatOpenAI by moving the `stream_usage`
attribute and `stream_usage` kwarg (on `_(a)stream`) from ChatOpenAI to
BaseChatOpenAI.
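
With the move, the same opt-in works for AzureChatOpenAI. A sketch (the endpoint, deployment, and API version below are placeholders, and an API key is assumed to be in the environment):

```python
from langchain_openai import AzureChatOpenAI

# Assumes AZURE_OPENAI_API_KEY is set in the environment.
llm = AzureChatOpenAI(
    azure_endpoint="https://example-resource.openai.azure.com/",  # placeholder
    azure_deployment="gpt-4o-mini",  # placeholder deployment name
    api_version="2024-10-21",  # placeholder API version
    stream_usage=True,  # now inherited from BaseChatOpenAI
)

for chunk in llm.stream("hello"):
    if chunk.usage_metadata:  # only the final, extra chunk carries usage
        print(chunk.usage_metadata)
```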

---

Additional consideration: we must be mindful of users who use BaseChatOpenAI to
interact with other APIs that do not support the `stream_options` parameter.
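
Because `stream_usage` defaults to False, `stream_options` is only added to the request when a user opts in, so such providers keep working out of the box; explicitly passed `stream_options` (e.g. via `model_kwargs`) also still takes precedence over the attribute. A sketch (the base URL, API key, and model names are placeholders):

```python
from langchain_openai import ChatOpenAI

# An OpenAI-compatible provider that rejects stream_options: nothing changes,
# because stream_options is only sent when stream_usage resolves to True.
llm = ChatOpenAI(
    base_url="https://example-provider.local/v1",  # placeholder
    api_key="not-a-real-key",  # placeholder
    model="some-model",  # placeholder
)

# Existing code that opted in through model_kwargs keeps its behavior:
# an explicit stream_options setting overrides the stream_usage attribute.
llm_opted_in = ChatOpenAI(
    model="gpt-4o-mini",
    api_key="not-a-real-key",  # placeholder
    model_kwargs={"stream_options": {"include_usage": True}},
)
```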

Suppose OpenAI updates the default behavior in the future to stream token
usage. BaseChatOpenAI currently passes `stream_options` only when
`stream_usage` resolves to True, so there would be no way to disable that new
default behavior.

To address this, we could change the `stream_usage` attribute to
`Optional[bool] = None`, but that is technically a breaking change (currently,
values of False are not passed to the client at all). IMO, if/when that change
happens, we could ship this update alongside it in a minor version bump.
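
To make the trade-off concrete, a hypothetical sketch of how a tri-state `Optional[bool]` attribute could map onto request options (illustration only, not part of this PR):

```python
from typing import Optional


def stream_options_from(stream_usage: Optional[bool]) -> dict:
    """Hypothetical mapping of a tri-state stream_usage onto request kwargs."""
    if stream_usage is None:
        # Defer to whatever the API does by default (today: no usage chunk).
        return {}
    # Forward explicit True/False, so users could opt out if the API default
    # ever flips to streaming usage.
    return {"stream_options": {"include_usage": stream_usage}}
```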

--- 

Related previous PRs:
- https://github.com/langchain-ai/langchain/pull/22628
- https://github.com/langchain-ai/langchain/pull/22854
- https://github.com/langchain-ai/langchain/pull/23552

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
ccurme 2025-03-26 15:16:37 -04:00 committed by GitHub
parent 56629ed87b
commit 8119a7bc5c
2 changed files with 41 additions and 52 deletions
libs/partners/openai/langchain_openai/chat_models
libs/partners/openai/tests/integration_tests/chat_models


@@ -456,6 +456,12 @@ class BaseChatOpenAI(BaseChatModel):
    )
    """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
        None."""
    stream_usage: bool = False
    """Whether to include usage metadata in streaming output. If True, an additional
    message chunk will be generated during the stream including usage metadata.

    .. versionadded:: 0.3.9
    """
    max_retries: Optional[int] = None
    """Maximum number of retries to make when generating."""
    presence_penalty: Optional[float] = None
@@ -811,14 +817,38 @@ class BaseChatOpenAI(BaseChatModel):
                is_first_chunk = False
                yield generation_chunk

    def _should_stream_usage(
        self, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> bool:
        """Determine whether to include usage metadata in streaming output.

        For backwards compatibility, we check for `stream_options` passed
        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
        """
        stream_usage_sources = [  # order of precedence
            stream_usage,
            kwargs.get("stream_options", {}).get("include_usage"),
            self.model_kwargs.get("stream_options", {}).get("include_usage"),
            self.stream_usage,
        ]
        for source in stream_usage_sources:
            if isinstance(source, bool):
                return source
        return self.stream_usage

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        *,
        stream_usage: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
        if stream_usage:
            kwargs["stream_options"] = {"include_usage": stream_usage}
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        base_generation_info = {}
@@ -1005,9 +1035,14 @@ class BaseChatOpenAI(BaseChatModel):
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        *,
        stream_usage: Optional[bool] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        kwargs["stream"] = True
        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
        if stream_usage:
            kwargs["stream_options"] = {"include_usage": stream_usage}
        payload = self._get_request_payload(messages, stop=stop, **kwargs)
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        base_generation_info = {}
@@ -2202,11 +2237,6 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
    """  # noqa: E501

    stream_usage: bool = False
    """Whether to include usage metadata in streaming output. If True, additional
    message chunks will be generated during the stream including usage metadata.
    """

    max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
    """Maximum number of tokens to generate."""
@@ -2268,55 +2298,21 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
                    message["role"] = "developer"
        return payload

    def _should_stream_usage(
        self, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> bool:
        """Determine whether to include usage metadata in streaming output.

        For backwards compatibility, we check for `stream_options` passed
        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
        """
        stream_usage_sources = [  # order of preference
            stream_usage,
            kwargs.get("stream_options", {}).get("include_usage"),
            self.model_kwargs.get("stream_options", {}).get("include_usage"),
            self.stream_usage,
        ]
        for source in stream_usage_sources:
            if isinstance(source, bool):
                return source
        return self.stream_usage

    def _stream(
        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Set default stream_options."""
    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
        """Route to Chat Completions or Responses API."""
        if self._use_responses_api({**kwargs, **self.model_kwargs}):
            return super()._stream_responses(*args, **kwargs)
        else:
            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
            # Note: stream_options is not a valid parameter for Azure OpenAI.
            # To support users proxying Azure through ChatOpenAI, here we only specify
            # stream_options if include_usage is set to True.
            # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
            # for release notes.
            if stream_usage:
                kwargs["stream_options"] = {"include_usage": stream_usage}
            return super()._stream(*args, **kwargs)

    async def _astream(
        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Set default stream_options."""
        """Route to Chat Completions or Responses API."""
        if self._use_responses_api({**kwargs, **self.model_kwargs}):
            async for chunk in super()._astream_responses(*args, **kwargs):
                yield chunk
        else:
            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
            if stream_usage:
                kwargs["stream_options"] = {"include_usage": stream_usage}
            async for chunk in super()._astream(*args, **kwargs):
                yield chunk


@@ -3,7 +3,6 @@
import os
from typing import Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests
@@ -25,6 +24,7 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
            "model": "gpt-4o-mini",
            "openai_api_version": OPENAI_API_VERSION,
            "azure_endpoint": OPENAI_API_BASE,
            "stream_usage": True,
        }

    @property
@@ -35,10 +35,6 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
    def supports_json_mode(self) -> bool:
        return True

    @pytest.mark.xfail(reason="Not yet supported.")
    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
        super().test_usage_metadata_streaming(model)


class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
    """Test a legacy model."""
@@ -53,12 +49,9 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
            "deployment_name": os.environ["AZURE_OPENAI_LEGACY_CHAT_DEPLOYMENT_NAME"],
            "openai_api_version": OPENAI_API_VERSION,
            "azure_endpoint": OPENAI_API_BASE,
            "stream_usage": True,
        }

    @property
    def structured_output_kwargs(self) -> dict:
        return {"method": "function_calling"}

    @pytest.mark.xfail(reason="Not yet supported.")
    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
        super().test_usage_metadata_streaming(model)