Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-21 14:18:52 +00:00
openai: don't override stream_options default (#22242)
ChatOpenAI supports a kwarg `stream_options`, which can take the values `{"include_usage": True}` and `{"include_usage": False}`. Setting `include_usage` to True adds a message chunk to the end of the stream with `usage_metadata` populated. In this case the final chunk no longer includes `"finish_reason"` in the `response_metadata`. This is currently the default and has not yet been released. Because this could be disruptive to workflows, here we remove this default. The default will now be consistent with OpenAI's API (see the parameter [here](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)).

Examples:

```python
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()
for chunk in llm.stream("hi"):
    print(chunk)
```
```
content='' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='Hello' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='!' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
content='' response_metadata={'finish_reason': 'stop'} id='run-8cff4721-2acd-4551-9bf7-1911dae46b92'
```
```python
for chunk in llm.stream("hi", stream_options={"include_usage": True}):
    print(chunk)
```
```
content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='Hello' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='!' id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='' response_metadata={'finish_reason': 'stop'} id='run-39ab349b-f954-464d-af6e-72a0927daa27'
content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```
```python
llm = ChatOpenAI().bind(stream_options={"include_usage": True})
for chunk in llm.stream("hi"):
    print(chunk)
```
```
content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='Hello' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='!' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='' response_metadata={'finish_reason': 'stop'} id='run-59918845-04b2-41a6-8d90-f75fb4506e0d'
content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```
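The examples above print each chunk separately. As a minimal sketch (illustration only, not part of this commit), the chunks can also be folded into a single message with `+`, the same pattern the integration tests below use, so that `finish_reason` and the usage totals can be read off the combined message:

```python
from langchain_openai import ChatOpenAI

# Opt in to usage metadata explicitly, since there is no longer a default.
llm = ChatOpenAI().bind(stream_options={"include_usage": True})

aggregate = None
for chunk in llm.stream("hi"):
    # AIMessageChunk supports "+": content, response_metadata and
    # usage_metadata are merged across chunks.
    aggregate = chunk if aggregate is None else aggregate + chunk

print(aggregate.response_metadata.get("finish_reason"))  # e.g. "stop"
print(aggregate.usage_metadata)  # e.g. {'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17}
```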
This commit is contained in:
parent a1899439fc
commit af1f723ada
```diff
@@ -1162,29 +1162,6 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
-    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-
-        return super()._stream(*args, **kwargs)
-
-    async def _astream(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-
-        async for chunk in super()._astream(*args, **kwargs):
-            yield chunk
-
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
```
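For context on the override deleted above: it merged a default into whatever the caller passed, with the caller's value taking precedence. A standalone sketch (illustration only, mirroring the removed lines) of that precedence:

```python
# Mirrors the removed override: the default is applied first, then the
# caller-supplied stream_options, so an explicit value always wins.
default_stream_options = {"include_usage": True}

kwargs = {"stream_options": {"include_usage": False}}  # hypothetical caller input
merged = {**default_stream_options, **kwargs.get("stream_options", {})}

assert merged == {"include_usage": False}  # an explicit opt-out was still honored
```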
```diff
@@ -346,10 +346,18 @@ def test_stream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
```
```diff
@@ -359,21 +367,11 @@ def test_stream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_astream() -> None:
```
```diff
@@ -381,10 +379,18 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
```
```diff
@@ -394,21 +400,11 @@ async def test_astream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_abatch() -> None:
```