openai: read stream_options (#21548)

OpenAI recently added a `stream_options` parameter to its chat
completions API (see [release
notes](https://platform.openai.com/docs/changelog/added-chat-completions-stream-usage)).
When this parameter is set to `{"include_usage": True}`, an extra "empty"
chunk containing token usage for the request is appended to the end of the
stream. Here we propagate that token usage to `AIMessage.usage_metadata`.
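
For example, a minimal sketch of consuming the usage chunk (assumes `OPENAI_API_KEY` is set; the prompt is arbitrary and the printed numbers will vary):
```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# The usage chunk survives chunk aggregation, so summing the stream
# with `+` yields a message whose usage_metadata is populated.
full = None
for chunk in llm.stream("Hello!"):
    full = chunk if full is None else full + chunk

print(full.usage_metadata)
# e.g. {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```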

We enable this feature by default. Streams now include an extra chunk at
the end, **after** the chunk with
`response_metadata={'finish_reason': 'stop'}`.

New behavior:
```
[AIMessageChunk(content='', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='Hello', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='!', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde', usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17})]
```

Old behavior (recoverable by passing `stream_options={"include_usage":
False}` into `(a)stream`):
```
[AIMessageChunk(content='', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='Hello', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='!', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-1312b971-c5ea-4d92-9015-e6604535f339')]
```

From what I can tell this is not yet implemented in Azure, so we enable
it only for `ChatOpenAI`.
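
The default is applied by merging any user-supplied `stream_options` over `{"include_usage": True}` (see the `_stream`/`_astream` overrides below), so an explicit `include_usage: False` always wins. A minimal sketch of the merge semantics:
```python
default_stream_options = {"include_usage": True}

# No user options: the default applies and the usage chunk is emitted.
merged = {**default_stream_options, **{}}
assert merged == {"include_usage": True}

# Explicit opt-out: the user's value overrides the default.
merged = {**default_stream_options, **{"include_usage": False}}
assert merged == {"include_usage": False}
```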
ccurme committed 2024-05-24 13:20:56 -04:00 (commit 9a010fb761, parent eb7c767e5b)
4 changed files with 136 additions and 36 deletions

libs/partners/openai/langchain_openai/chat_models/base.py

@@ -58,6 +58,7 @@ from langchain_core.messages import (
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
@@ -483,7 +484,20 @@ class BaseChatOpenAI(BaseChatModel):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                if token_usage := chunk.get("usage"):
                    usage_metadata = UsageMetadata(
                        input_tokens=token_usage.get("prompt_tokens", 0),
                        output_tokens=token_usage.get("completion_tokens", 0),
                        total_tokens=token_usage.get("total_tokens", 0),
                    )
                    chunk = ChatGenerationChunk(
                        message=default_chunk_class(
                            content="", usage_metadata=usage_metadata
                        )
                    )
                else:
                    continue
            else:
                choice = chunk["choices"][0]
                if choice["delta"] is None:
                    continue
@@ -589,7 +603,20 @@ class BaseChatOpenAI(BaseChatModel):
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            if len(chunk["choices"]) == 0:
                if token_usage := chunk.get("usage"):
                    usage_metadata = UsageMetadata(
                        input_tokens=token_usage.get("prompt_tokens", 0),
                        output_tokens=token_usage.get("completion_tokens", 0),
                        total_tokens=token_usage.get("total_tokens", 0),
                    )
                    chunk = ChatGenerationChunk(
                        message=default_chunk_class(
                            content="", usage_metadata=usage_metadata
                        )
                    )
                else:
                    continue
            else:
                choice = chunk["choices"][0]
                if choice["delta"] is None:
                    continue
@@ -1135,6 +1162,29 @@ class ChatOpenAI(BaseChatOpenAI):
        """Return whether this model can be serialized by Langchain."""
        return True

    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
        """Set default stream_options."""
        default_stream_options = {"include_usage": True}
        stream_options = kwargs.get("stream_options", {})
        merged_stream_options = {**default_stream_options, **stream_options}
        kwargs["stream_options"] = merged_stream_options
        return super()._stream(*args, **kwargs)

    async def _astream(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Set default stream_options."""
        default_stream_options = {"include_usage": True}
        stream_options = kwargs.get("stream_options", {})
        merged_stream_options = {**default_stream_options, **stream_options}
        kwargs["stream_options"] = merged_stream_options
        async for chunk in super()._astream(*args, **kwargs):
            yield chunk


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and issubclass(obj, BaseModel)

libs/partners/openai/poetry.lock

@@ -1268,4 +1268,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "0d142c2f1c4d08dbdcbd15b8005ffa606d4d36781c52e9411993cabd41de261b"
+content-hash = "62f0a24221c64dc8035ccf7cca3f8ac2eaf47d653a441645c8021120833ecb52"

libs/partners/openai/pyproject.toml

@@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = {version =">=0.2.2rc1,<0.3", allow-prereleases=true}
-openai = "^1.24.0"
+openai = "^1.26.0"
tiktoken = ">=0.7,<1"

[tool.poetry.group.test]

libs/partners/openai/tests/integration_tests/chat_models/test_base.py

@@ -346,9 +346,34 @@ def test_stream() -> None:
    llm = ChatOpenAI()

    full: Optional[BaseMessageChunk] = None
    chunks_with_token_counts = 0
    for chunk in llm.stream("I'm Pickle Rick"):
        assert isinstance(chunk.content, str)
        full = chunk if full is None else full + chunk
        assert isinstance(chunk, AIMessageChunk)
        if chunk.usage_metadata is not None:
            chunks_with_token_counts += 1
    if chunks_with_token_counts != 1:
        raise AssertionError(
            "Expected exactly one chunk with token counts. "
            "AIMessageChunk aggregation adds counts. Check that "
            "this is behaving properly."
        )
    # check token usage is populated
    assert isinstance(full, AIMessageChunk)
    assert full.usage_metadata is not None
    assert full.usage_metadata["input_tokens"] > 0
    assert full.usage_metadata["output_tokens"] > 0
    assert full.usage_metadata["total_tokens"] > 0
    # check not populated
    aggregate: Optional[BaseMessageChunk] = None
    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
        assert isinstance(chunk.content, str)
        aggregate = chunk if aggregate is None else aggregate + chunk
    assert isinstance(aggregate, AIMessageChunk)
    assert aggregate.usage_metadata is None
async def test_astream() -> None:
@@ -356,9 +381,34 @@ async def test_astream() -> None:
    llm = ChatOpenAI()

    full: Optional[BaseMessageChunk] = None
    chunks_with_token_counts = 0
    async for chunk in llm.astream("I'm Pickle Rick"):
        assert isinstance(chunk.content, str)
        full = chunk if full is None else full + chunk
        assert isinstance(chunk, AIMessageChunk)
        if chunk.usage_metadata is not None:
            chunks_with_token_counts += 1
    if chunks_with_token_counts != 1:
        raise AssertionError(
            "Expected exactly one chunk with token counts. "
            "AIMessageChunk aggregation adds counts. Check that "
            "this is behaving properly."
        )
    # check token usage is populated
    assert isinstance(full, AIMessageChunk)
    assert full.usage_metadata is not None
    assert full.usage_metadata["input_tokens"] > 0
    assert full.usage_metadata["output_tokens"] > 0
    assert full.usage_metadata["total_tokens"] > 0
    # check not populated
    aggregate: Optional[BaseMessageChunk] = None
    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
        assert isinstance(chunk.content, str)
        aggregate = chunk if aggregate is None else aggregate + chunk
    assert isinstance(aggregate, AIMessageChunk)
    assert aggregate.usage_metadata is None
async def test_abatch() -> None: