From 741bb1ffa1cc7a41565a49d5d0934a8187b34b27 Mon Sep 17 00:00:00 2001
From: ccurme
Date: Thu, 5 Jun 2025 11:20:06 -0400
Subject: [PATCH] core[patch]: revert change to stream type hint (#31501)

https://github.com/langchain-ai/langchain/pull/31286 included an update to
the return type for `BaseChatModel.(a)stream`, from
`Iterator[BaseMessageChunk]` to `Iterator[BaseMessage]`. This change is
correct in principle: when streaming is disabled, the stream methods return
an iterator of `BaseMessage`, and a `BaseMessage` is not a
`BaseMessageChunk` (although the reverse is true, since `BaseMessageChunk`
subclasses `BaseMessage`).

However, LangChain's docs consistently use a pattern of [summing
BaseMessageChunks](https://python.langchain.com/docs/how_to/streaming/#llms-and-chat-models)
to accumulate a chat model stream. This pattern is implemented in the tests
of most integration packages and appears in application code.

As a result, https://github.com/langchain-ai/langchain/pull/31286 introduces
mypy errors throughout the ecosystem (or, more accurately, it reveals that
this pattern does not account for use of the `.stream` method when
streaming is disabled).

Here we revert just the change to the stream return type to unblock things.
A proper fix should address the docs and integration packages (or, if we
elect to force people to update their code, be explicit about that).
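For reference, a minimal sketch of the documented accumulation pattern (the
chunks below are illustrative stand-ins for what `model.stream(...)` yields):

```python
import operator
from functools import reduce

from langchain_core.messages import AIMessageChunk

# Stand-ins for the chunks yielded by `model.stream(...)`.
chunks = [AIMessageChunk(content="Hello"), AIMessageChunk(content=" world")]

# `__add__` is defined on BaseMessageChunk, not on BaseMessage, so under the
# Iterator[BaseMessage] hint from #31286 this sum fails mypy even though it
# works at runtime whenever streaming is enabled.
accumulated = reduce(operator.add, chunks)
assert accumulated.content == "Hello world"
```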
---
 .../langchain_core/language_models/chat_models.py | 15 +++++++++++----
 .../tests/unit_tests/fake/test_fake_chat_model.py |  9 ++++++---
 .../language_models/chat_models/test_base.py      |  9 +++++++--
 .../tests/unit_tests/llms/test_fake_chat_model.py |  9 ++++++---
 4 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 5b7165e83a0..f74270a9ac8 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -51,6 +51,7 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -445,10 +446,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessage]:
+    ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                self.invoke(input, config=config, stop=stop, **kwargs),
+            )
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
@@ -533,10 +537,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessage]:
+    ) -> AsyncIterator[BaseMessageChunk]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
             return
 
         config = ensure_config(config)
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index c39440f5e32..7500e1640ac 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""
 
-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1
 
-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk
 
     assert accumulate_chunks == AIMessageChunk(
         content="",
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 00e7c8e9f26..cc3037b9177 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -163,10 +163,15 @@ async def test_astream_fallback_to_ainvoke() -> None:
 
     model = ModelWithGenerate()
     chunks = list(model.stream("anything"))
-    assert chunks == [_any_id_ai_message(content="hello")]
+    # BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
+    # When streaming is disabled, it returns Iterator[BaseMessage], so the type hint
+    # is not strictly correct.
+    # LangChain documents a pattern of adding BaseMessageChunks to accumulate a stream.
+    # This may be better done with `reduce(operator.add, chunks)`.
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]
 
     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_any_id_ai_message(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]
 
 
 async def test_astream_implementation_fallback_to_stream() -> None:
diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
index ed8d47a71f2..a403e3d027f 100644
--- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
+++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""
 
-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]
 
-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk
 
     assert accumulate_chunks == AIMessageChunk(
         id="a1",