openai: don't override stream_options default (#22242)

ChatOpenAI supports a kwarg `stream_options` which can take values
`{"include_usage": True}` and `{"include_usage": False}`.

Setting `include_usage` to `True` appends a message chunk to the end of the stream with `usage_metadata` populated. In that case the final chunk no longer includes `"finish_reason"` in its `response_metadata`. This is the current default but has not yet been released; because it could be disruptive to existing workflows, this PR removes that default. The behavior is now consistent with OpenAI's API (see the `stream_options` parameter
[here](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)).

Examples:
```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

for chunk in llm.stream("hi"):
    print(chunk)
```
```
content='' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='Hello' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='!' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='' response_metadata={'finish_reason': 'stop'} id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
```

```python
for chunk in llm.stream("hi", stream_options={"include_usage": True}):
    print(chunk)
```
```
content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='Hello' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='!' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='' response_metadata={'finish_reason': 'stop'} id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```

```python
llm = ChatOpenAI().bind(stream_options={"include_usage": True})

for chunk in llm.stream("hi"):
    print(chunk)
```
```
content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='Hello' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='!' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='' response_metadata={'finish_reason': 'stop'} id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```
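
For workflows that want token counts after this change, opting in per call and aggregating the chunks recovers them from the combined message. A minimal sketch, mirroring the updated integration tests below (variable names are illustrative):

```python
from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# Opt in to usage reporting for this call only; an extra, content-free
# chunk carrying usage_metadata is appended to the end of the stream.
aggregate: Optional[BaseMessageChunk] = None
for chunk in llm.stream("hi", stream_options={"include_usage": True}):
    aggregate = chunk if aggregate is None else aggregate + chunk

# Chunk aggregation preserves the counts from that final chunk.
if isinstance(aggregate, AIMessageChunk) and aggregate.usage_metadata is not None:
    print(aggregate.usage_metadata["total_tokens"])
```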
ccurme committed 2024-05-29 10:30:40 -04:00 (via GitHub)
2 changed files with 26 additions and 53 deletions


```diff
@@ -1162,29 +1162,6 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
-    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-        return super()._stream(*args, **kwargs)
-
-    async def _astream(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-        async for chunk in super()._astream(*args, **kwargs):
-            yield chunk
-
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
```
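
Workflows that depended on the removed default can restore it without waiting for a release, either with the `.bind(stream_options={"include_usage": True})` pattern shown above or, if the merge-with-caller-overrides behavior of the deleted methods matters, by reinstating it in a small subclass. A sketch under that assumption (the subclass name is hypothetical, not part of `langchain_openai`):

```python
from typing import Any, AsyncIterator, Iterator

from langchain_core.outputs import ChatGenerationChunk
from langchain_openai import ChatOpenAI


class UsageStreamingChatOpenAI(ChatOpenAI):
    """Hypothetical subclass re-applying the removed default stream_options."""

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
        # Caller-supplied stream_options take precedence over the default.
        kwargs["stream_options"] = {
            "include_usage": True,
            **kwargs.get("stream_options", {}),
        }
        return super()._stream(*args, **kwargs)

    async def _astream(
        self, *args: Any, **kwargs: Any
    ) -> AsyncIterator[ChatGenerationChunk]:
        kwargs["stream_options"] = {
            "include_usage": True,
            **kwargs.get("stream_options", {}),
        }
        async for chunk in super()._astream(*args, **kwargs):
            yield chunk
```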


```diff
@@ -346,10 +346,18 @@ def test_stream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
@@ -359,21 +367,11 @@ def test_stream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_astream() -> None:
@@ -381,10 +379,18 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
@@ -394,21 +400,11 @@ async def test_astream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_abatch() -> None:
```
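
The async path behaves the same way; a corresponding sketch of the two cases exercised by the updated `test_astream` (run inside any event loop, e.g. via `asyncio.run`):

```python
import asyncio
from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI()

    # Default: the last chunk carries finish_reason, no usage_metadata.
    full: Optional[BaseMessageChunk] = None
    async for chunk in llm.astream("hi"):
        full = chunk if full is None else full + chunk
    assert isinstance(full, AIMessageChunk)
    print(full.response_metadata.get("finish_reason"))

    # Opt-in: an extra final chunk adds usage_metadata.
    aggregate: Optional[BaseMessageChunk] = None
    async for chunk in llm.astream("hi", stream_options={"include_usage": True}):
        aggregate = chunk if aggregate is None else aggregate + chunk
    assert isinstance(aggregate, AIMessageChunk)
    print(aggregate.usage_metadata)


asyncio.run(main())
```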