openai: don't override stream_options default (#22242)

ChatOpenAI supports a kwarg `stream_options`, which can take the values
`{"include_usage": True}` and `{"include_usage": False}`.

Setting `include_usage` to `True` adds a message chunk to the end of the
stream with `usage_metadata` populated. In that case the final content chunk
no longer includes `"finish_reason"` in its `response_metadata`. This behavior
is currently the default, but it has not yet been released. Because it could
be disruptive to existing workflows, here we remove the default. The default
now matches OpenAI's API (see the parameter
[here](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)).
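
For reference, this is the same parameter exposed by the OpenAI Python SDK; a minimal sketch of the raw API call (the model name is a placeholder, and an `OPENAI_API_KEY` is assumed to be set):

```python
from openai import OpenAI

client = OpenAI()

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
    stream_options={"include_usage": True},  # same parameter ChatOpenAI forwards
)
for event in stream:
    # When include_usage is True, the final event carries a populated `usage` field.
    print(event)
```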

Examples:
```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

for chunk in llm.stream("hi"):
    print(chunk)
```
```
content='' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='Hello' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='!' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='' response_metadata={'finish_reason': 'stop'} id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
```

```python
for chunk in llm.stream("hi", stream_options={"include_usage": True}):
    print(chunk)
```
```
content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='Hello' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='!' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='' response_metadata={'finish_reason': 'stop'} id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```

```python
llm = ChatOpenAI().bind(stream_options={"include_usage": True})

for chunk in llm.stream("hi"):
    print(chunk)
```
```
content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='Hello' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='!' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='' response_metadata={'finish_reason': 'stop'} id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```
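
The same opt-in applies to async streaming; a minimal sketch using `astream` (assuming an `OPENAI_API_KEY` is configured):

```python
import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI()
    # As with stream(), usage metadata is only attached when explicitly requested.
    async for chunk in llm.astream("hi", stream_options={"include_usage": True}):
        print(chunk)


asyncio.run(main())
```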
Committed by ccurme on 2024-05-29 10:30:40 -04:00 via GitHub (commit af1f723ada, parent a1899439fc).
2 changed files with 26 additions and 53 deletions


```diff
@@ -1162,29 +1162,6 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
-    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-        return super()._stream(*args, **kwargs)
-
-    async def _astream(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-        async for chunk in super()._astream(*args, **kwargs):
-            yield chunk
-
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
```
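
Callers that relied on the removed default can still recover per-call usage by binding the option and aggregating the chunks, as the updated tests below do; a minimal sketch:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI().bind(stream_options={"include_usage": True})

full = None
for chunk in llm.stream("hi"):
    full = chunk if full is None else full + chunk

# Aggregating the chunks carries the usage_metadata from the final chunk.
print(full.usage_metadata)
```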


```diff
@@ -346,10 +346,18 @@ def test_stream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
@@ -359,21 +367,11 @@ def test_stream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_astream() -> None:
@@ -381,10 +379,18 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
@@ -394,21 +400,11 @@ async def test_astream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
        )
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_abatch() -> None:
```