openai[patch]: support streaming token counts in AzureChatOpenAI (#30494)
When OpenAI originally released `stream_options` to enable token usage reporting during streaming, it was not supported in AzureOpenAI. It is now supported.
Like the [OpenAI SDK](f66d2e6fdc/src/openai/resources/completions.py (L68)), ChatOpenAI does not return usage metadata during streaming by default, since doing so adds an extra chunk to the stream. The OpenAI SDK requires users to pass `stream_options={"include_usage": True}`. ChatOpenAI instead provides a convenience keyword argument `stream_usage: Optional[bool]` on its streaming methods and an attribute `stream_usage: bool = False`.
Here we extend this to AzureChatOpenAI by moving the `stream_usage`
attribute and `stream_usage` kwarg (on `_(a)stream`) from ChatOpenAI to
BaseChatOpenAI.
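For reference, a minimal usage sketch (the deployment name and API version are placeholders, not taken from this PR):

```python
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
    azure_deployment="gpt-4o-mini",  # hypothetical deployment name
    api_version="2024-10-21",        # hypothetical API version
    stream_usage=True,               # emit an extra chunk carrying token counts
)

full = None
for chunk in llm.stream("Write a haiku about streaming."):
    # Aggregate chunks; token counts arrive on the final usage chunk.
    full = chunk if full is None else full + chunk

print(full.usage_metadata)  # input/output/total token counts
```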
---
Additional consideration: we must be mindful of users who rely on BaseChatOpenAI to interact with other APIs that do not support the `stream_options` parameter.
Suppose OpenAI in the future updates the default behavior to stream
token usage. Currently, BaseChatOpenAI only passes `stream_options` if
`stream_usage` is True, so there would be no way to disable this new
default behavior.
To address this, we could update the `stream_usage` attribute to `Optional[bool] = None`, but this is technically a breaking change (currently, values of False are not passed to the client at all). IMO: if / when that change happens, we could ship this update alongside it in a minor version bump.
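A hypothetical sketch of that tri-state variant (illustration only, not what this PR implements): with `Optional[bool]`, an explicit `False` would reach the client and could opt out of a future provider-side default, while `None` would defer to the provider.

```python
from typing import Any, Dict, Optional


def apply_stream_usage(kwargs: Dict[str, Any], stream_usage: Optional[bool]) -> None:
    """Hypothetical tri-state handling: forward both True and False, defer on None."""
    if stream_usage is not None:
        # An explicit False would now be sent to the client, letting users disable
        # a future provider-side default; None keeps the provider's behavior.
        kwargs["stream_options"] = {"include_usage": stream_usage}
```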
---
Related previous PRs:
- https://github.com/langchain-ai/langchain/pull/22628
- https://github.com/langchain-ai/langchain/pull/22854
- https://github.com/langchain-ai/langchain/pull/23552
---------
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Parent: 56629ed87b
Commit: 8119a7bc5c
@@ -456,6 +456,12 @@ class BaseChatOpenAI(BaseChatModel):
     )
     """Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
     None."""
+    stream_usage: bool = False
+    """Whether to include usage metadata in streaming output. If True, an additional
+    message chunk will be generated during the stream including usage metadata.
+
+    .. versionadded:: 0.3.9
+    """
     max_retries: Optional[int] = None
     """Maximum number of retries to make when generating."""
     presence_penalty: Optional[float] = None
@@ -811,14 +817,38 @@ class BaseChatOpenAI(BaseChatModel):
                 is_first_chunk = False
                 yield generation_chunk
 
+    def _should_stream_usage(
+        self, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> bool:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of precedence
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
     def _stream(
         self,
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         kwargs["stream"] = True
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
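To illustrate the precedence order listed in `_should_stream_usage` above, a short sketch (not part of the diff): a per-call `stream_usage` kwarg beats a per-call `stream_options`, which beats `model_kwargs`, which beats the `stream_usage` attribute.

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

# The per-call keyword argument takes precedence over the attribute,
# so no usage chunk is emitted for this stream.
for _ in llm.stream("hi", stream_usage=False):
    pass

# Passing stream_options explicitly (the raw OpenAI SDK style) also takes
# precedence over the attribute, for backwards compatibility.
llm2 = ChatOpenAI(model="gpt-4o-mini")  # stream_usage defaults to False
for _ in llm2.stream("hi", stream_options={"include_usage": True}):
    pass
```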
@@ -1005,9 +1035,14 @@ class BaseChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        *,
+        stream_usage: Optional[bool] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
         kwargs["stream"] = True
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
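The async path mirrors the sync one; a quick sketch of per-call usage with `astream` (the deployment name is a placeholder):

```python
import asyncio

from langchain_openai import AzureChatOpenAI


async def main() -> None:
    llm = AzureChatOpenAI(azure_deployment="gpt-4o-mini")  # hypothetical deployment
    async for chunk in llm.astream("hi", stream_usage=True):
        if chunk.usage_metadata:
            # Only the final chunk carries the token counts.
            print(chunk.usage_metadata)


asyncio.run(main())
```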
@@ -2202,11 +2237,6 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
 
     """  # noqa: E501
 
-    stream_usage: bool = False
-    """Whether to include usage metadata in streaming output. If True, additional
-    message chunks will be generated during the stream including usage metadata.
-    """
-
     max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
     """Maximum number of tokens to generate."""
 
@@ -2268,55 +2298,21 @@ class ChatOpenAI(BaseChatOpenAI):  # type: ignore[override]
             message["role"] = "developer"
         return payload
 
-    def _should_stream_usage(
-        self, stream_usage: Optional[bool] = None, **kwargs: Any
-    ) -> bool:
-        """Determine whether to include usage metadata in streaming output.
-
-        For backwards compatibility, we check for `stream_options` passed
-        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
-        """
-        stream_usage_sources = [  # order of preference
-            stream_usage,
-            kwargs.get("stream_options", {}).get("include_usage"),
-            self.model_kwargs.get("stream_options", {}).get("include_usage"),
-            self.stream_usage,
-        ]
-        for source in stream_usage_sources:
-            if isinstance(source, bool):
-                return source
-        return self.stream_usage
-
-    def _stream(
-        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
-    ) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
             return super()._stream_responses(*args, **kwargs)
         else:
-            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-            # Note: stream_options is not a valid parameter for Azure OpenAI.
-            # To support users proxying Azure through ChatOpenAI, here we only specify
-            # stream_options if include_usage is set to True.
-            # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-            # for release notes.
-            if stream_usage:
-                kwargs["stream_options"] = {"include_usage": stream_usage}
-
             return super()._stream(*args, **kwargs)
 
     async def _astream(
-        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+        self, *args: Any, **kwargs: Any
     ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
+        """Route to Chat Completions or Responses API."""
         if self._use_responses_api({**kwargs, **self.model_kwargs}):
             async for chunk in super()._astream_responses(*args, **kwargs):
                 yield chunk
         else:
-            stream_usage = self._should_stream_usage(stream_usage, **kwargs)
-            if stream_usage:
-                kwargs["stream_options"] = {"include_usage": stream_usage}
-
             async for chunk in super()._astream(*args, **kwargs):
                 yield chunk
 
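Since `_should_stream_usage` also inspects `model_kwargs`, the pre-existing way of requesting usage metadata keeps working; a small backwards-compatibility sketch:

```python
from langchain_openai import ChatOpenAI

# Older code that configured stream_options via model_kwargs behaves as before.
llm = ChatOpenAI(
    model="gpt-4o-mini",
    model_kwargs={"stream_options": {"include_usage": True}},
)
for chunk in llm.stream("hello"):
    if chunk.usage_metadata:
        print(chunk.usage_metadata)
```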
@@ -3,7 +3,6 @@
 import os
 from typing import Type
 
-import pytest
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import ChatModelIntegrationTests
 
@@ -25,6 +24,7 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
             "model": "gpt-4o-mini",
             "openai_api_version": OPENAI_API_VERSION,
             "azure_endpoint": OPENAI_API_BASE,
+            "stream_usage": True,
         }
 
     @property
@@ -35,10 +35,6 @@ class TestAzureOpenAIStandard(ChatModelIntegrationTests):
     def supports_json_mode(self) -> bool:
         return True
 
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)
-
 
 class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
     """Test a legacy model."""
@@ -53,12 +49,9 @@ class TestAzureOpenAIStandardLegacy(ChatModelIntegrationTests):
             "deployment_name": os.environ["AZURE_OPENAI_LEGACY_CHAT_DEPLOYMENT_NAME"],
             "openai_api_version": OPENAI_API_VERSION,
             "azure_endpoint": OPENAI_API_BASE,
+            "stream_usage": True,
         }
 
     @property
     def structured_output_kwargs(self) -> dict:
         return {"method": "function_calling"}
-
-    @pytest.mark.xfail(reason="Not yet supported.")
-    def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
-        super().test_usage_metadata_streaming(model)