mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-09 09:08:40 +00:00
openai[patch]: support streaming token counts in AzureChatOpenAI (#30494)
When OpenAI originally released `stream_options` to enable token usage
during streaming, it was not supported in AzureOpenAI. It is now
supported.
Like the [OpenAI
SDK](f66d2e6fdc/src/openai/resources/completions.py (L68)
),
ChatOpenAI does not return usage metadata during streaming by default
(which adds an extra chunk to the stream). The OpenAI SDK requires users
to pass `stream_options={"include_usage": True}`. ChatOpenAI implements
a convenience argument `stream_usage: Optional[bool]`, and an attribute
`stream_usage: bool = False`.
Here we extend this to AzureChatOpenAI by moving the `stream_usage`
attribute and `stream_usage` kwarg (on `_(a)stream`) from ChatOpenAI to
BaseChatOpenAI.
---
Additional consideration: we must be sensitive to the number of users
using BaseChatOpenAI to interact with other APIs that do not support the
`stream_options` parameter.
Suppose OpenAI in the future updates the default behavior to stream
token usage. Currently, BaseChatOpenAI only passes `stream_options` if
`stream_usage` is True, so there would be no way to disable this new
default behavior.
To address this, we could update the `stream_usage` attribute to
`Optional[bool] = None`, but this is technically a breaking change (as
currently values of False are not passed to the client). IMO: if / when
this change happens, we could accompany it with this update in a minor
bump.
---
Related previous PRs:
- https://github.com/langchain-ai/langchain/pull/22628
- https://github.com/langchain-ai/langchain/pull/22854
- https://github.com/langchain-ai/langchain/pull/23552
---------
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
56629ed87b
commit
8119a7bc5c
libs/partners/openai
@ -456,6 +456,12 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
)
|
||||
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
|
||||
None."""
|
||||
stream_usage: bool = False
|
||||
"""Whether to include usage metadata in streaming output. If True, an additional
|
||||
message chunk will be generated during the stream including usage metadata.
|
||||
|
||||
.. versionadded:: 0.3.9
|
||||
"""
|
||||
max_retries: Optional[int] = None
|
||||
"""Maximum number of retries to make when generating."""
|
||||
presence_penalty: Optional[float] = None
|
||||
@ -811,14 +817,38 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
is_first_chunk = False
|
||||
yield generation_chunk
|
||||
|
||||
def _should_stream_usage(
|
||||
self, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
) -> bool:
|
||||
"""Determine whether to include usage metadata in streaming output.
|
||||
|
||||
For backwards compatibility, we check for `stream_options` passed
|
||||
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
|
||||
"""
|
||||
stream_usage_sources = [ # order of precedence
|
||||
stream_usage,
|
||||
kwargs.get("stream_options", {}).get("include_usage"),
|
||||
self.model_kwargs.get("stream_options", {}).get("include_usage"),
|
||||
self.stream_usage,
|
||||
]
|
||||
for source in stream_usage_sources:
|
||||
if isinstance(source, bool):
|
||||
return source
|
||||
return self.stream_usage
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
stream_usage: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
kwargs["stream"] = True
|
||||
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
|
||||
if stream_usage:
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
base_generation_info = {}
|
||||
@ -1005,9 +1035,14 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
*,
|
||||
stream_usage: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
kwargs["stream"] = True
|
||||
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
|
||||
if stream_usage:
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
payload = self._get_request_payload(messages, stop=stop, **kwargs)
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
base_generation_info = {}
|
||||
@ -2202,11 +2237,6 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
stream_usage: bool = False
|
||||
"""Whether to include usage metadata in streaming output. If True, additional
|
||||
message chunks will be generated during the stream including usage metadata.
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
|
||||
"""Maximum number of tokens to generate."""
|
||||
|
||||
@ -2268,55 +2298,21 @@ class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
|
||||
message["role"] = "developer"
|
||||
return payload
|
||||
|
||||
def _should_stream_usage(
|
||||
self, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
) -> bool:
|
||||
"""Determine whether to include usage metadata in streaming output.
|
||||
|
||||
For backwards compatibility, we check for `stream_options` passed
|
||||
explicitly to kwargs or in the model_kwargs and override self.stream_usage.
|
||||
"""
|
||||
stream_usage_sources = [ # order of preference
|
||||
stream_usage,
|
||||
kwargs.get("stream_options", {}).get("include_usage"),
|
||||
self.model_kwargs.get("stream_options", {}).get("include_usage"),
|
||||
self.stream_usage,
|
||||
]
|
||||
for source in stream_usage_sources:
|
||||
if isinstance(source, bool):
|
||||
return source
|
||||
return self.stream_usage
|
||||
|
||||
def _stream(
|
||||
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
"""Set default stream_options."""
|
||||
def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
|
||||
"""Route to Chat Completions or Responses API."""
|
||||
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||
return super()._stream_responses(*args, **kwargs)
|
||||
else:
|
||||
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
|
||||
# Note: stream_options is not a valid parameter for Azure OpenAI.
|
||||
# To support users proxying Azure through ChatOpenAI, here we only specify
|
||||
# stream_options if include_usage is set to True.
|
||||
# See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
|
||||
# for release notes.
|
||||
if stream_usage:
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
|
||||
return super()._stream(*args, **kwargs)
|
||||
|
||||
async def _astream(
|
||||
self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
"""Set default stream_options."""
|
||||
"""Route to Chat Completions or Responses API."""
|
||||
if self._use_responses_api({**kwargs, **self.model_kwargs}):
|
||||
async for chunk in super()._astream_responses(*args, **kwargs):
|
||||
yield chunk
|
||||
else:
|
||||
stream_usage = self._should_stream_usage(stream_usage, **kwargs)
|
||||
if stream_usage:
|
||||
kwargs["stream_options"] = {"include_usage": stream_usage}
|
||||
|
||||
async for chunk in super()._astream(*args, **kwargs):
|
||||
yield chunk
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
import os
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
@ -25,6 +24,7 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
|
||||
"model": "gpt-4o-mini",
|
||||
"openai_api_version": OPENAI_API_VERSION,
|
||||
"azure_endpoint": OPENAI_API_BASE,
|
||||
"stream_usage": True,
|
||||
}
|
||||
|
||||
@property
|
||||
@ -35,10 +35,6 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
|
||||
def supports_json_mode(self) -> bool:
|
||||
return True
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet supported.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata_streaming(model)
|
||||
|
||||
|
||||
class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
|
||||
"""Test a legacy model."""
|
||||
@ -53,12 +49,9 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
|
||||
"deployment_name": os.environ["AZURE_OPENAI_LEGACY_CHAT_DEPLOYMENT_NAME"],
|
||||
"openai_api_version": OPENAI_API_VERSION,
|
||||
"azure_endpoint": OPENAI_API_BASE,
|
||||
"stream_usage": True,
|
||||
}
|
||||
|
||||
@property
|
||||
def structured_output_kwargs(self) -> dict:
|
||||
return {"method": "function_calling"}
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet supported.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata_streaming(model)
|
||||
|
Loading…
Reference in New Issue
Block a user