From af1f723ada03edbabb45c5efe51d209caa60f3c3 Mon Sep 17 00:00:00 2001 From: ccurme Date: Wed, 29 May 2024 10:30:40 -0400 Subject: [PATCH] openai: don't override stream_options default (#22242) ChatOpenAI supports a kwarg `stream_options` which can take values `{"include_usage": True}` and `{"include_usage": False}`. Setting include_usage to True adds a message chunk to the end of the stream with usage_metadata populated. In this case the final chunk no longer includes `"finish_reason"` in the `response_metadata`. This is the current default and is not yet released. Because this could be disruptive to workflows, here we remove this default. The default will now be consistent with OpenAI's API (see parameter [here](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)). Examples: ```python from langchain_openai import ChatOpenAI llm = ChatOpenAI() for chunk in llm.stream("hi"): print(chunk) ``` ``` content='' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92' content='Hello' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92' content='!' id='run-8cff4721-2acd-4551-9bf7-1911dae46b92' content='' response_metadata={'finish_reason': 'stop'} id='run-8cff4721-2acd-4551-9bf7-1911dae46b92' ``` ```python for chunk in llm.stream("hi", stream_options={"include_usage": True}): print(chunk) ``` ``` content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27' content='Hello' id='run-39ab349b-f954-464d-af6e-72a0927daa27' content='!' id='run-39ab349b-f954-464d-af6e-72a0927daa27' content='' response_metadata={'finish_reason': 'stop'} id='run-39ab349b-f954-464d-af6e-72a0927daa27' content='' id='run-39ab349b-f954-464d-af6e-72a0927daa27' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17} ``` ```python llm = ChatOpenAI().bind(stream_options={"include_usage": True}) for chunk in llm.stream("hi"): print(chunk) ``` ``` content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' content='Hello' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' content='!' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' content='' response_metadata={'finish_reason': 'stop'} id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' content='' id='run-59918845-04b2-41a6-8d90-f75fb4506e0d' usage_metadata={'input_tokens': 8, 'output_tokens': 9, 'total_tokens': 17} ``` --- .../langchain_openai/chat_models/base.py | 23 -------- .../chat_models/test_base.py | 56 +++++++++---------- 2 files changed, 26 insertions(+), 53 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 84459cc4d29..7133956bb40 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -1162,29 +1162,6 @@ class ChatOpenAI(BaseChatOpenAI): """Return whether this model can be serialized by Langchain.""" return True - def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]: - """Set default stream_options.""" - default_stream_options = {"include_usage": True} - stream_options = kwargs.get("stream_options", {}) - merged_stream_options = {**default_stream_options, **stream_options} - kwargs["stream_options"] = merged_stream_options - - return super()._stream(*args, **kwargs) - - async def _astream( - self, - *args: Any, - **kwargs: Any, - ) -> AsyncIterator[ChatGenerationChunk]: - """Set default stream_options.""" - default_stream_options = {"include_usage": True} - stream_options = kwargs.get("stream_options", {}) - merged_stream_options = {**default_stream_options, **stream_options} - kwargs["stream_options"] = merged_stream_options - - async for chunk in super()._astream(*args, **kwargs): - yield chunk - def _is_pydantic_class(obj: Any) -> bool: return isinstance(obj, type) and issubclass(obj, BaseModel) diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index db87949eb64..5a31bc76f9d 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -346,10 +346,18 @@ def test_stream() -> None: llm = ChatOpenAI() full: Optional[BaseMessageChunk] = None - chunks_with_token_counts = 0 for chunk in llm.stream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.response_metadata.get("finish_reason") is not None + + # check token usage + aggregate: Optional[BaseMessageChunk] = None + chunks_with_token_counts = 0 + for chunk in llm.stream("Hello", stream_options={"include_usage": True}): + assert isinstance(chunk.content, str) + aggregate = chunk if aggregate is None else aggregate + chunk assert isinstance(chunk, AIMessageChunk) if chunk.usage_metadata is not None: chunks_with_token_counts += 1 @@ -359,21 +367,11 @@ def test_stream() -> None: "AIMessageChunk aggregation adds counts. Check that " "this is behaving properly." ) - - # check token usage is populated - assert isinstance(full, AIMessageChunk) - assert full.usage_metadata is not None - assert full.usage_metadata["input_tokens"] > 0 - assert full.usage_metadata["output_tokens"] > 0 - assert full.usage_metadata["total_tokens"] > 0 - - # check not populated - aggregate: Optional[BaseMessageChunk] = None - for chunk in llm.stream("Hello", stream_options={"include_usage": False}): - assert isinstance(chunk.content, str) - aggregate = chunk if aggregate is None else aggregate + chunk assert isinstance(aggregate, AIMessageChunk) - assert aggregate.usage_metadata is None + assert aggregate.usage_metadata is not None + assert aggregate.usage_metadata["input_tokens"] > 0 + assert aggregate.usage_metadata["output_tokens"] > 0 + assert aggregate.usage_metadata["total_tokens"] > 0 async def test_astream() -> None: @@ -381,10 +379,18 @@ async def test_astream() -> None: llm = ChatOpenAI() full: Optional[BaseMessageChunk] = None - chunks_with_token_counts = 0 async for chunk in llm.astream("I'm Pickle Rick"): assert isinstance(chunk.content, str) full = chunk if full is None else full + chunk + assert isinstance(full, AIMessageChunk) + assert full.response_metadata.get("finish_reason") is not None + + # check token usage + aggregate: Optional[BaseMessageChunk] = None + chunks_with_token_counts = 0 + async for chunk in llm.astream("Hello", stream_options={"include_usage": True}): + assert isinstance(chunk.content, str) + aggregate = chunk if aggregate is None else aggregate + chunk assert isinstance(chunk, AIMessageChunk) if chunk.usage_metadata is not None: chunks_with_token_counts += 1 @@ -394,21 +400,11 @@ async def test_astream() -> None: "AIMessageChunk aggregation adds counts. Check that " "this is behaving properly." ) - - # check token usage is populated - assert isinstance(full, AIMessageChunk) - assert full.usage_metadata is not None - assert full.usage_metadata["input_tokens"] > 0 - assert full.usage_metadata["output_tokens"] > 0 - assert full.usage_metadata["total_tokens"] > 0 - - # check not populated - aggregate: Optional[BaseMessageChunk] = None - async for chunk in llm.astream("Hello", stream_options={"include_usage": False}): - assert isinstance(chunk.content, str) - aggregate = chunk if aggregate is None else aggregate + chunk assert isinstance(aggregate, AIMessageChunk) - assert aggregate.usage_metadata is None + assert aggregate.usage_metadata is not None + assert aggregate.usage_metadata["input_tokens"] > 0 + assert aggregate.usage_metadata["output_tokens"] > 0 + assert aggregate.usage_metadata["total_tokens"] > 0 async def test_abatch() -> None: