diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 5b7165e83a0..f74270a9ac8 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -51,6 +51,7 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -445,10 +446,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessage]:
+    ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                self.invoke(input, config=config, stop=stop, **kwargs),
+            )
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
@@ -533,10 +537,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessage]:
+    ) -> AsyncIterator[BaseMessageChunk]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
             return

         config = ensure_config(config)
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index c39440f5e32..7500e1640ac 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         content="",
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
index 00e7c8e9f26..cc3037b9177 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py
@@ -163,10 +163,15 @@ async def test_astream_fallback_to_ainvoke() -> None:

     model = ModelWithGenerate()
     chunks = list(model.stream("anything"))
-    assert chunks == [_any_id_ai_message(content="hello")]
+    # BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
+    # When streaming is disabled, it returns Iterator[BaseMessage], so the type hint
+    # is not strictly correct.
+    # LangChain documents a pattern of adding BaseMessageChunks to accumulate a stream.
+    # This may be better done with `reduce(operator.add, chunks)`.
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]

     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_any_id_ai_message(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]


 async def test_astream_implementation_fallback_to_stream() -> None:
diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
index ed8d47a71f2..a403e3d027f 100644
--- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
+++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         id="a1",