openai: read stream_options (#21548)
OpenAI recently added a `stream_options` parameter to its chat completions API (see the [release notes](https://platform.openai.com/docs/changelog/added-chat-completions-stream-usage)). When this parameter is set to `{"include_usage": True}`, an extra "empty" chunk containing token usage is appended to the end of the stream. Here we propagate that token usage to `AIMessage.usage_metadata` and enable the feature by default, so streams now include an extra chunk at the end, **after** the chunk with `response_metadata={'finish_reason': 'stop'}`.

New behavior:

```
[AIMessageChunk(content='', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='Hello', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='!', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde', usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17})]
```

Old behavior (still accessible by passing `stream_options={"include_usage": False}` into `(a)stream`):

```
[AIMessageChunk(content='', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='Hello', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='!', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-1312b971-c5ea-4d92-9015-e6604535f339')]
```

From what I can tell this is not yet implemented in Azure, so we enable it only for `ChatOpenAI`.
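For context, this is how the new default surfaces to a caller. A minimal sketch, assuming a valid `OPENAI_API_KEY` in the environment; the prompts and printed counts are illustrative, not captured output:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# With the new default (stream_options={"include_usage": True}), the final
# chunk has empty content and carries usage_metadata.
usage = None
for chunk in llm.stream("I'm Pickle Rick"):
    if chunk.usage_metadata is not None:
        usage = chunk.usage_metadata
print(usage)  # e.g. {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}

# Opting back into the old behavior: no usage chunk is emitted.
for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
    assert chunk.usage_metadata is None
```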
parent eb7c767e5b · commit 9a010fb761
Changes to the `langchain_openai` chat model module (the file path is not shown in this view). The import hunk pulls in `UsageMetadata`:

```diff
@@ -58,6 +58,7 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
```
In the sync streaming path, a chunk with empty `choices` is now treated as the usage chunk rather than skipped:

```diff
@@ -483,23 +484,36 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    usage_metadata = UsageMetadata(
+                        input_tokens=token_usage.get("prompt_tokens", 0),
+                        output_tokens=token_usage.get("completion_tokens", 0),
+                        total_tokens=token_usage.get("total_tokens", 0),
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(
+                            content="", usage_metadata=usage_metadata
+                        )
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text, chunk=chunk, logprobs=logprobs
```
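For reference, a minimal sketch of the mapping this hunk performs, assuming the final-chunk shape OpenAI documents for `stream_options={"include_usage": true}` (empty `choices` plus a top-level `usage` dict). The literal token counts below are illustrative, not captured output:

```python
from langchain_core.messages.ai import UsageMetadata

# Hypothetical final stream chunk (illustrative values): choices is empty and
# the top-level usage dict carries the totals for the whole request.
final_chunk = {
    "choices": [],
    "usage": {"prompt_tokens": 8, "completion_tokens": 9, "total_tokens": 17},
}

token_usage = final_chunk.get("usage") or {}
usage_metadata = UsageMetadata(
    input_tokens=token_usage.get("prompt_tokens", 0),
    output_tokens=token_usage.get("completion_tokens", 0),
    total_tokens=token_usage.get("total_tokens", 0),
)
print(usage_metadata)  # {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```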
The async streaming path gets the same treatment:

```diff
@@ -589,23 +603,36 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    usage_metadata = UsageMetadata(
+                        input_tokens=token_usage.get("prompt_tokens", 0),
+                        output_tokens=token_usage.get("completion_tokens", 0),
+                        total_tokens=token_usage.get("total_tokens", 0),
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(
+                            content="", usage_metadata=usage_metadata
+                        )
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 await run_manager.on_llm_new_token(
                     token=chunk.text, chunk=chunk, logprobs=logprobs
```
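The async path can be exercised the same way as the sync example above; a minimal sketch, again assuming a configured `OPENAI_API_KEY` (model choice and prompt are arbitrary):

```python
import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI(model="gpt-3.5-turbo")  # model choice is illustrative
    usage = None
    async for chunk in llm.astream("Hello"):
        # The final chunk carries usage_metadata and empty content.
        if chunk.usage_metadata is not None:
            usage = chunk.usage_metadata
    print(usage)


asyncio.run(main())
```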
`ChatOpenAI` then opts into the usage chunk by default, merging `{"include_usage": True}` under any caller-supplied `stream_options`:

```diff
@@ -1135,6 +1162,29 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        return super()._stream(*args, **kwargs)
+
+    async def _astream(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        async for chunk in super()._astream(*args, **kwargs):
+            yield chunk
+
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
```
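A small sketch of why the caller still wins: the default is unpacked first, so a later `stream_options` entry overrides it (plain Python, nothing LangChain-specific):

```python
default_stream_options = {"include_usage": True}

# No caller option: the default survives.
print({**default_stream_options, **{}})  # {'include_usage': True}

# Caller opts out: the later unpacking overrides the default.
print({**default_stream_options, **{"include_usage": False}})  # {'include_usage': False}
```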
`libs/partners/openai/poetry.lock` (generated; 2 lines changed):

```diff
@@ -1268,4 +1268,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "0d142c2f1c4d08dbdcbd15b8005ffa606d4d36781c52e9411993cabd41de261b"
+content-hash = "62f0a24221c64dc8035ccf7cca3f8ac2eaf47d653a441645c8021120833ecb52"
```
The `openai` dependency floor is raised to a release that supports `stream_options`:

```diff
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = {version =">=0.2.2rc1,<0.3", allow-prereleases=true}
-openai = "^1.24.0"
+openai = "^1.26.0"
 tiktoken = ">=0.7,<1"
 
 [tool.poetry.group.test]
```
The streaming integration tests now assert that exactly one chunk carries token counts and that usage is absent when the caller opts out. Only the updated side of each hunk is visible in this view. New body of `test_stream` (`@@ -346,9 +346,34 @@`):

```python
    llm = ChatOpenAI()

    full: Optional[BaseMessageChunk] = None
    chunks_with_token_counts = 0
    for chunk in llm.stream("I'm Pickle Rick"):
        assert isinstance(chunk.content, str)
        full = chunk if full is None else full + chunk
        assert isinstance(chunk, AIMessageChunk)
        if chunk.usage_metadata is not None:
            chunks_with_token_counts += 1
    if chunks_with_token_counts != 1:
        raise AssertionError(
            "Expected exactly one chunk with token counts. "
            "AIMessageChunk aggregation adds counts. Check that "
            "this is behaving properly."
        )

    # check token usage is populated
    assert isinstance(full, AIMessageChunk)
    assert full.usage_metadata is not None
    assert full.usage_metadata["input_tokens"] > 0
    assert full.usage_metadata["output_tokens"] > 0
    assert full.usage_metadata["total_tokens"] > 0

    # check not populated
    aggregate: Optional[BaseMessageChunk] = None
    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
        assert isinstance(chunk.content, str)
        aggregate = chunk if aggregate is None else aggregate + chunk
    assert isinstance(aggregate, AIMessageChunk)
    assert aggregate.usage_metadata is None
```
New body of `test_astream` (`@@ -356,9 +381,34 @@`):

```python
    llm = ChatOpenAI()

    full: Optional[BaseMessageChunk] = None
    chunks_with_token_counts = 0
    async for chunk in llm.astream("I'm Pickle Rick"):
        assert isinstance(chunk.content, str)
        full = chunk if full is None else full + chunk
        assert isinstance(chunk, AIMessageChunk)
        if chunk.usage_metadata is not None:
            chunks_with_token_counts += 1
    if chunks_with_token_counts != 1:
        raise AssertionError(
            "Expected exactly one chunk with token counts. "
            "AIMessageChunk aggregation adds counts. Check that "
            "this is behaving properly."
        )

    # check token usage is populated
    assert isinstance(full, AIMessageChunk)
    assert full.usage_metadata is not None
    assert full.usage_metadata["input_tokens"] > 0
    assert full.usage_metadata["output_tokens"] > 0
    assert full.usage_metadata["total_tokens"] > 0

    # check not populated
    aggregate: Optional[BaseMessageChunk] = None
    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
        assert isinstance(chunk.content, str)
        aggregate = chunk if aggregate is None else aggregate + chunk
    assert isinstance(aggregate, AIMessageChunk)
    assert aggregate.usage_metadata is None
```
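As the assertion message in these tests notes, `AIMessageChunk` aggregation adds token counts, which is why summing the whole stream yields the final usage. A minimal offline sketch (no API call; contents and counts are made up):

```python
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(content="Hello")  # no usage on this chunk
right = AIMessageChunk(
    content="",
    usage_metadata={"input_tokens": 8, "output_tokens": 9, "total_tokens": 17},
)

# Adding chunks concatenates content and sums usage (a missing side counts as zero).
merged = left + right
print(merged.usage_metadata)  # {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```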