openai[patch]: add stream_usage parameter (#22854)
Here we add `stream_usage` to ChatOpenAI as:
1. a boolean attribute
2. a kwarg to `_stream` and `_astream`.
Question: should the `stream_usage` attribute be `bool`, or `bool | None`?
Currently I've kept it `bool` and defaulted to False. It was implemented on
[ChatAnthropic](e832bbb486/libs/partners/anthropic/langchain_anthropic/chat_models.py (L535))
as a bool. However, to maintain support for users who access the behavior via
OpenAI's `stream_options` param, this ends up being possible:
```python
llm = ChatOpenAI(model_kwargs={"stream_options": {"include_usage": True}})
assert not llm.stream_usage
```
(and this model will stream token usage).
Some options for this:
- leave it as is (the attribute and the `model_kwargs` route coexist)
- make the `stream_usage` attribute `bool | None`
- add an `__init__` for ChatOpenAI that sets a `._stream_usage` attribute and
  reads `.stream_usage` from a property (see the sketch below)

Open to other ideas as well.
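
For the third option, a minimal sketch of the shape (a plain class that ignores
the pydantic wiring a real ChatOpenAI subclass would need; `StreamUsageShape`
and its fields are illustrative, not the actual implementation):

```python
from typing import Any, Dict, Optional


class StreamUsageShape:
    """Shape-only sketch of the property idea; not a real ChatOpenAI subclass."""

    def __init__(
        self,
        *,
        stream_usage: Optional[bool] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        # None means the caller did not set stream_usage explicitly.
        self._stream_usage = stream_usage
        self.model_kwargs = model_kwargs or {}

    @property
    def stream_usage(self) -> bool:
        if self._stream_usage is not None:
            return self._stream_usage
        # Fall back to OpenAI's stream_options passed through model_kwargs.
        return bool(
            self.model_kwargs.get("stream_options", {}).get("include_usage", False)
        )


llm = StreamUsageShape(model_kwargs={"stream_options": {"include_usage": True}})
assert llm.stream_usage  # the property now reflects the model_kwargs behavior
```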
Parent: 56ac94e014
Commit: 722c8f50ea
```diff
@@ -495,6 +495,7 @@ class BaseChatOpenAI(BaseChatModel):
                             content="", usage_metadata=usage_metadata
                         )
                     )
+                    logprobs = None
                 else:
                     continue
             else:
@@ -619,6 +620,7 @@ class BaseChatOpenAI(BaseChatModel):
                             content="", usage_metadata=usage_metadata
                         )
                     )
+                    logprobs = None
                 else:
                     continue
             else:
```
```diff
@@ -1386,11 +1388,11 @@ class ChatOpenAI(BaseChatOpenAI):

             {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}

-        When streaming, set the ``stream_options`` model kwarg:
+        When streaming, set the ``stream_usage`` kwarg:

         .. code-block:: python

-            stream = llm.stream(messages, stream_options={"include_usage": True})
+            stream = llm.stream(messages, stream_usage=True)
             full = next(stream)
             for chunk in stream:
                 full += chunk
@@ -1400,7 +1402,7 @@ class ChatOpenAI(BaseChatOpenAI):

             {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}

-        Alternatively, setting ``stream_options`` when instantiating the model can be
+        Alternatively, setting ``stream_usage`` when instantiating the model can be
         useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using
         methods like ``.with_structured_output``, which generate chains under the
         hood.
@@ -1409,7 +1411,7 @@ class ChatOpenAI(BaseChatOpenAI):

             llm = ChatOpenAI(
                 model="gpt-4o",
-                model_kwargs={"stream_options": {"include_usage": True}},
+                stream_usage=True,
             )
             structured_llm = llm.with_structured_output(...)

```
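
Pulling the updated docstring examples together, a minimal end-to-end usage
sketch (assumes an OpenAI API key is configured; the model name simply follows
the docstring example):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o", stream_usage=True)

full = None
for chunk in llm.stream("Hello"):
    full = chunk if full is None else full + chunk

# With stream_usage=True the aggregated message carries usage_metadata,
# e.g. {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}
print(full.usage_metadata)
```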
```diff
@@ -1446,6 +1448,11 @@ class ChatOpenAI(BaseChatOpenAI):

     """ # noqa: E501

+    stream_usage: bool = False
+    """Whether to include usage metadata in streaming output. If True, additional
+    message chunks will be generated during the stream including usage metadata.
+    """
+
     @property
     def lc_secrets(self) -> Dict[str, str]:
         return {"openai_api_key": "OPENAI_API_KEY"}
```
```diff
@@ -1475,6 +1482,44 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True

+    def _should_stream_usage(
+        self, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> bool:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of preference
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
+    def _stream(
+        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> Iterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        kwargs["stream_options"] = {"include_usage": stream_usage}
+
+        return super()._stream(*args, **kwargs)
+
+    async def _astream(
+        self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        stream_usage = self._should_stream_usage(stream_usage, **kwargs)
+        kwargs["stream_options"] = {"include_usage": stream_usage}
+
+        async for chunk in super()._astream(*args, **kwargs):
+            yield chunk
+

 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
```
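
To make the resolution order concrete, here is a standalone mirror of the
`_should_stream_usage` logic (illustrative only; `resolve_stream_usage` is not
part of the library):

```python
from typing import Any, Optional


def resolve_stream_usage(
    attr_stream_usage: bool,
    model_kwargs: dict,
    stream_usage: Optional[bool] = None,
    **kwargs: Any,
) -> bool:
    """Mirror the order of preference used by ChatOpenAI._should_stream_usage."""
    sources = [
        stream_usage,  # 1. explicit stream_usage kwarg to _stream/_astream
        kwargs.get("stream_options", {}).get("include_usage"),  # 2. per-call stream_options
        model_kwargs.get("stream_options", {}).get("include_usage"),  # 3. instance model_kwargs
        attr_stream_usage,  # 4. the stream_usage attribute (defaults to False)
    ]
    for source in sources:
        if isinstance(source, bool):
            return source
    return attr_stream_usage


# A per-call stream_options wins over an instance-level model_kwargs setting:
resolved = resolve_stream_usage(
    attr_stream_usage=False,
    model_kwargs={"stream_options": {"include_usage": True}},
    stream_options={"include_usage": False},
)
assert resolved is False
```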
```diff
@@ -1,5 +1,5 @@
 """Test ChatOpenAI chat model."""
-from typing import Any, List, Optional, cast
+from typing import Any, AsyncIterator, List, Optional, cast

 import pytest
 from langchain_core.callbacks import CallbackManager
```
```diff
@@ -357,7 +357,7 @@ def test_stream() -> None:
     aggregate: Optional[BaseMessageChunk] = None
     chunks_with_token_counts = 0
     chunks_with_response_metadata = 0
-    for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
+    for chunk in llm.stream("Hello", stream_usage=True):
         assert isinstance(chunk.content, str)
         aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
```
```diff
@@ -380,39 +380,73 @@ def test_stream() -> None:

 async def test_astream() -> None:
     """Test streaming tokens from OpenAI."""
-    llm = ChatOpenAI()

-    full: Optional[BaseMessageChunk] = None
-    async for chunk in llm.astream("I'm Pickle Rick"):
-        assert isinstance(chunk.content, str)
-        full = chunk if full is None else full + chunk
-    assert isinstance(full, AIMessageChunk)
-    assert full.response_metadata.get("finish_reason") is not None
-    assert full.response_metadata.get("model_name") is not None
-
-    # check token usage
-    aggregate: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
-    chunks_with_response_metadata = 0
-    async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
-        assert isinstance(chunk, AIMessageChunk)
-        if chunk.usage_metadata is not None:
-            chunks_with_token_counts += 1
-        if chunk.response_metadata:
-            chunks_with_response_metadata += 1
-    if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
-        raise AssertionError(
-            "Expected exactly one chunk with metadata. "
-            "AIMessageChunk aggregation can add these metadata. Check that "
-            "this is behaving properly."
-        )
-    assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is not None
-    assert aggregate.usage_metadata["input_tokens"] > 0
-    assert aggregate.usage_metadata["output_tokens"] > 0
-    assert aggregate.usage_metadata["total_tokens"] > 0
+    async def _test_stream(stream: AsyncIterator, expect_usage: bool) -> None:
+        full: Optional[BaseMessageChunk] = None
+        chunks_with_token_counts = 0
+        chunks_with_response_metadata = 0
+        async for chunk in stream:
+            assert isinstance(chunk.content, str)
+            full = chunk if full is None else full + chunk
+            assert isinstance(chunk, AIMessageChunk)
+            if chunk.usage_metadata is not None:
+                chunks_with_token_counts += 1
+            if chunk.response_metadata:
+                chunks_with_response_metadata += 1
+        assert isinstance(full, AIMessageChunk)
+        if chunks_with_response_metadata != 1:
+            raise AssertionError(
+                "Expected exactly one chunk with metadata. "
+                "AIMessageChunk aggregation can add these metadata. Check that "
+                "this is behaving properly."
+            )
+        assert full.response_metadata.get("finish_reason") is not None
+        assert full.response_metadata.get("model_name") is not None
+        if expect_usage:
+            if chunks_with_token_counts != 1:
+                raise AssertionError(
+                    "Expected exactly one chunk with token counts. "
+                    "AIMessageChunk aggregation adds counts. Check that "
+                    "this is behaving properly."
+                )
+            assert full.usage_metadata is not None
+            assert full.usage_metadata["input_tokens"] > 0
+            assert full.usage_metadata["output_tokens"] > 0
+            assert full.usage_metadata["total_tokens"] > 0
+        else:
+            assert chunks_with_token_counts == 0
+            assert full.usage_metadata is None
+
+    llm = ChatOpenAI(temperature=0, max_tokens=5)
+    await _test_stream(llm.astream("Hello"), expect_usage=False)
+    await _test_stream(
+        llm.astream("Hello", stream_options={"include_usage": True}),
+        expect_usage=True,
+    )
+    await _test_stream(
+        llm.astream("Hello", stream_usage=True),
+        expect_usage=True,
+    )
+    llm = ChatOpenAI(
+        temperature=0,
+        max_tokens=5,
+        model_kwargs={"stream_options": {"include_usage": True}},
+    )
+    await _test_stream(llm.astream("Hello"), expect_usage=True)
+    await _test_stream(
+        llm.astream("Hello", stream_options={"include_usage": False}),
+        expect_usage=False,
+    )
+    llm = ChatOpenAI(
+        temperature=0,
+        max_tokens=5,
+        stream_usage=True,
+    )
+    await _test_stream(llm.astream("Hello"), expect_usage=True)
+    await _test_stream(
+        llm.astream("Hello", stream_usage=False),
+        expect_usage=False,
+    )


 async def test_abatch() -> None:
```