diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 84459cc4d29..7133956bb40 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -1162,29 +1162,6 @@ class ChatOpenAI(BaseChatOpenAI):
         """Return whether this model can be serialized by Langchain."""
         return True
 
-    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-
-        return super()._stream(*args, **kwargs)
-
-    async def _astream(
-        self,
-        *args: Any,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        """Set default stream_options."""
-        default_stream_options = {"include_usage": True}
-        stream_options = kwargs.get("stream_options", {})
-        merged_stream_options = {**default_stream_options, **stream_options}
-        kwargs["stream_options"] = merged_stream_options
-
-        async for chunk in super()._astream(*args, **kwargs):
-            yield chunk
-
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and issubclass(obj, BaseModel)
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index db87949eb64..5a31bc76f9d 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -346,10 +346,18 @@ def test_stream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     for chunk in llm.stream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
@@ -359,21 +367,11 @@ def test_stream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    for chunk in llm.stream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_astream() -> None:
@@ -381,10 +379,18 @@ async def test_astream() -> None:
     llm = ChatOpenAI()
 
     full: Optional[BaseMessageChunk] = None
-    chunks_with_token_counts = 0
     async for chunk in llm.astream("I'm Pickle Rick"):
         assert isinstance(chunk.content, str)
         full = chunk if full is None else full + chunk
+    assert isinstance(full, AIMessageChunk)
+    assert full.response_metadata.get("finish_reason") is not None
+
+    # check token usage
+    aggregate: Optional[BaseMessageChunk] = None
+    chunks_with_token_counts = 0
+    async for chunk in llm.astream("Hello", stream_options={"include_usage": True}):
+        assert isinstance(chunk.content, str)
+        aggregate = chunk if aggregate is None else aggregate + chunk
         assert isinstance(chunk, AIMessageChunk)
         if chunk.usage_metadata is not None:
             chunks_with_token_counts += 1
@@ -394,21 +400,11 @@ async def test_astream() -> None:
             "AIMessageChunk aggregation adds counts. Check that "
             "this is behaving properly."
         )
-
-    # check token usage is populated
-    assert isinstance(full, AIMessageChunk)
-    assert full.usage_metadata is not None
-    assert full.usage_metadata["input_tokens"] > 0
-    assert full.usage_metadata["output_tokens"] > 0
-    assert full.usage_metadata["total_tokens"] > 0
-
-    # check not populated
-    aggregate: Optional[BaseMessageChunk] = None
-    async for chunk in llm.astream("Hello", stream_options={"include_usage": False}):
-        assert isinstance(chunk.content, str)
-        aggregate = chunk if aggregate is None else aggregate + chunk
     assert isinstance(aggregate, AIMessageChunk)
-    assert aggregate.usage_metadata is None
+    assert aggregate.usage_metadata is not None
+    assert aggregate.usage_metadata["input_tokens"] > 0
+    assert aggregate.usage_metadata["output_tokens"] > 0
+    assert aggregate.usage_metadata["total_tokens"] > 0
 
 
 async def test_abatch() -> None:
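For callers, the practical effect of this revert is that `ChatOpenAI` no longer injects `stream_options={"include_usage": True}` into streaming requests, so usage metadata appears on streamed chunks only when explicitly requested. A minimal sketch mirroring the updated integration tests (the prompt is illustrative, and running it assumes a valid `OPENAI_API_KEY`):

```python
from typing import Optional

from langchain_core.messages import BaseMessageChunk
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

# Opt in to token usage explicitly; after this change it is no longer
# enabled by default when streaming.
aggregate: Optional[BaseMessageChunk] = None
for chunk in llm.stream("Hello", stream_options={"include_usage": True}):
    aggregate = chunk if aggregate is None else aggregate + chunk

# The aggregated AIMessageChunk carries token counts only because
# include_usage was requested above; without it, usage_metadata is None.
assert aggregate is not None
print(aggregate.usage_metadata)
```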