Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-19 19:11:33 +00:00)
openai: read stream_options (#21548)
OpenAI recently added a `stream_options` parameter to its chat completions API (see [release notes](https://platform.openai.com/docs/changelog/added-chat-completions-stream-usage)). When this parameter is set to `{"include_usage": True}`, an extra "empty" message is added to the end of a stream containing token usage. Here we propagate that token usage to `AIMessage.usage_metadata`.

We enable this feature by default. Streams now include an extra chunk at the end, **after** the chunk with `response_metadata={'finish_reason': 'stop'}`.

New behavior:
```
[AIMessageChunk(content='', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='Hello', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='!', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde'),
 AIMessageChunk(content='', id='run-4b20dbe0-3817-4f62-b89d-03ef76f25bde', usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17})]
```

Old behavior (accessible by passing `stream_options={"include_usage": False}` into `(a)stream`):
```
[AIMessageChunk(content='', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='Hello', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='!', id='run-1312b971-c5ea-4d92-9015-e6604535f339'),
 AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-1312b971-c5ea-4d92-9015-e6604535f339')]
```

From what I can tell this is not yet implemented in Azure, so we enable it only for ChatOpenAI.
Parent: eb7c767e5b
Commit: 9a010fb761
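For illustration, a minimal sketch of consuming the new default from user code (assumes an `OPENAI_API_KEY` in the environment; the model name and prompts are placeholders, not part of this change):

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo")  # placeholder model name

# With include_usage on by default, the final chunk carries token usage.
for chunk in llm.stream("Hello"):
    if chunk.usage_metadata is not None:
        print(chunk.usage_metadata)
        # e.g. {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}

# Opting back into the old behavior: no usage chunk is emitted.
for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
    assert chunk.usage_metadata is None
```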
```diff
@@ -58,6 +58,7 @@ from langchain_core.messages import (
     ToolMessage,
     ToolMessageChunk,
 )
+from langchain_core.messages.ai import UsageMetadata
 from langchain_core.output_parsers import (
     JsonOutputParser,
     PydanticOutputParser,
```
```diff
@@ -483,23 +484,36 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    usage_metadata = UsageMetadata(
+                        input_tokens=token_usage.get("prompt_tokens", 0),
+                        output_tokens=token_usage.get("completion_tokens", 0),
+                        total_tokens=token_usage.get("total_tokens", 0),
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(
+                            content="", usage_metadata=usage_metadata
+                        )
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 run_manager.on_llm_new_token(
                     chunk.text, chunk=chunk, logprobs=logprobs
```
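For reference, the hunk above maps OpenAI's `usage` payload onto `UsageMetadata` field by field. A standalone sketch of that mapping, using a made-up usage dict:

```python
from langchain_core.messages.ai import UsageMetadata

# Example payload shaped like chunk["usage"] in the diff above (values invented).
token_usage = {"prompt_tokens": 8, "completion_tokens": 9, "total_tokens": 17}

usage_metadata = UsageMetadata(
    input_tokens=token_usage.get("prompt_tokens", 0),
    output_tokens=token_usage.get("completion_tokens", 0),
    total_tokens=token_usage.get("total_tokens", 0),
)
print(usage_metadata)
# {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```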
```diff
@@ -589,23 +603,36 @@ class BaseChatOpenAI(BaseChatModel):
             if not isinstance(chunk, dict):
                 chunk = chunk.model_dump()
             if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            if choice["delta"] is None:
-                continue
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
+                if token_usage := chunk.get("usage"):
+                    usage_metadata = UsageMetadata(
+                        input_tokens=token_usage.get("prompt_tokens", 0),
+                        output_tokens=token_usage.get("completion_tokens", 0),
+                        total_tokens=token_usage.get("total_tokens", 0),
+                    )
+                    chunk = ChatGenerationChunk(
+                        message=default_chunk_class(
+                            content="", usage_metadata=usage_metadata
+                        )
+                    )
+                else:
+                    continue
+            else:
+                choice = chunk["choices"][0]
+                if choice["delta"] is None:
+                    continue
+                chunk = _convert_delta_to_message_chunk(
+                    choice["delta"], default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = chunk.__class__
+                chunk = ChatGenerationChunk(
+                    message=chunk, generation_info=generation_info or None
+                )
             if run_manager:
                 await run_manager.on_llm_new_token(
                     token=chunk.text, chunk=chunk, logprobs=logprobs
```
```diff
@@ -1135,6 +1162,29 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        return super()._stream(*args, **kwargs)
+
+    async def _astream(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Set default stream_options."""
+        default_stream_options = {"include_usage": True}
+        stream_options = kwargs.get("stream_options", {})
+        merged_stream_options = {**default_stream_options, **stream_options}
+        kwargs["stream_options"] = merged_stream_options
+
+        async for chunk in super()._astream(*args, **kwargs):
+            yield chunk
+
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
```
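The `_stream`/`_astream` overrides above rely on plain dict merging, so a caller-supplied `stream_options` key overrides the default. A quick sketch of that precedence (values are illustrative):

```python
# Caller-supplied keys win because they are unpacked last in the merge.
default_stream_options = {"include_usage": True}

merged = {**default_stream_options, **{"include_usage": False}}
print(merged)  # {'include_usage': False}

# With no caller overrides, the default is kept.
print({**default_stream_options, **{}})  # {'include_usage': True}
```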
libs/partners/openai/poetry.lock (generated; 2 lines changed)
```diff
@@ -1268,4 +1268,4 @@ watchmedo = ["PyYAML (>=3.10)"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "0d142c2f1c4d08dbdcbd15b8005ffa606d4d36781c52e9411993cabd41de261b"
+content-hash = "62f0a24221c64dc8035ccf7cca3f8ac2eaf47d653a441645c8021120833ecb52"
```
libs/partners/openai/pyproject.toml
```diff
@@ -13,7 +13,7 @@ license = "MIT"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 langchain-core = {version =">=0.2.2rc1,<0.3", allow-prereleases=true}
-openai = "^1.24.0"
+openai = "^1.26.0"
 tiktoken = ">=0.7,<1"
 
 [tool.poetry.group.test]
```
```diff
@@ -346,9 +346,34 @@ def test_stream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+        assert isinstance(chunk, AIMessageChunk)
+        if chunk.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+
+    # check not populated
+    aggregate: Optional[BaseMessageChunk] = None
+    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
+    assert isinstance(aggregate, AIMessageChunk)
+    assert aggregate.usage_metadata is None
 
 
 async def test_astream() -> None:
```
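The assertion message in the test above leans on `AIMessageChunk` aggregation carrying token counts through `+`. A rough sketch of the pattern the tests exercise (chunk contents and counts are made up; the aggregation behavior is assumed from the test's own wording and assertions):

```python
from langchain_core.messages import AIMessageChunk

chunks = [
    AIMessageChunk(content="Hello"),
    AIMessageChunk(content="!"),
    # The extra "empty" chunk emitted when include_usage is on.
    AIMessageChunk(
        content="",
        usage_metadata={"input_tokens": 8, "output_tokens": 9, "total_tokens": 17},
    ),
]

full = None
for chunk in chunks:
    full = chunk if full is None else full + chunk

print(full.content)         # Hello!
print(full.usage_metadata)  # counts from the single usage-bearing chunk survive aggregation
```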
```diff
@@ -356,9 +381,34 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+        assert isinstance(chunk, AIMessageChunk)
+        if chunk.usage_metadata is not None:
+            chunks_with_token_counts += 1
+    if chunks_with_token_counts != 1:
+        raise AssertionError(
+            "Expected exactly one chunk with token counts. "
+            "AIMessageChunk aggregation adds counts. Check that "
+            "this is behaving properly."
+        )
+
+    # check token usage is populated
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata is not None
+    assert full.usage_metadata["input_tokens"] > 0
+    assert full.usage_metadata["output_tokens"] > 0
+    assert full.usage_metadata["total_tokens"] > 0
+
+    # check not populated
+    aggregate: Optional[BaseMessageChunk] = None
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
+    assert isinstance(aggregate, AIMessageChunk)
+    assert aggregate.usage_metadata is None
 
 
 async def test_abatch() -> None:
```