core[patch]: revert change to stream type hint (#31501)

https://github.com/langchain-ai/langchain/pull/31286 updated the return
type of `BaseChatModel.stream` from `Iterator[BaseMessageChunk]` to
`Iterator[BaseMessage]` (and of `astream` from
`AsyncIterator[BaseMessageChunk]` to `AsyncIterator[BaseMessage]`).

This change is correct: when streaming is disabled, the stream methods
return an iterator of `BaseMessage`, and in the message hierarchy a
`BaseMessage` is not a `BaseMessageChunk` (though the reverse is true).
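
Concretely, a minimal sketch of that hierarchy (classes from `langchain_core.messages`):

```python
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
)

chunk = AIMessageChunk(content="hi")
message = AIMessage(content="hi")

# Every chunk is a message...
assert isinstance(chunk, BaseMessage)
# ...but a plain message is not a chunk, so an iterator that may yield plain
# messages cannot honestly be typed Iterator[BaseMessageChunk].
assert not isinstance(message, BaseMessageChunk)
```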

However, LangChain includes a pattern throughout its docs of [summing
BaseMessageChunks](https://python.langchain.com/docs/how_to/streaming/#llms-and-chat-models)
to accumulate a chat model stream. This pattern is implemented in tests
for most integration packages and appears in application code. So
https://github.com/langchain-ai/langchain/pull/31286 introduces mypy
errors throughout the ecosystem (or maybe more accurately, it reveals
that this pattern does not account for use of the `.stream` method when
streaming is disabled).
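
For illustration, here is a minimal sketch of that accumulation pattern and where the new hint bites; `FakeListChatModel` stands in for any chat model:

```python
from typing import Optional

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import BaseMessageChunk

model = FakeListChatModel(responses=["hello world"])

full: Optional[BaseMessageChunk] = None
for chunk in model.stream("hi"):
    # Under the Iterator[BaseMessage] hint from #31286, mypy rejects both the
    # assignment and the `+` below, because only BaseMessageChunk defines __add__.
    full = chunk if full is None else full + chunk
```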

Here we revert just the change to the stream return type to unblock
things. A proper fix should address the docs and integration packages (or,
if we elect to force people to update their code, be explicit about that).
Commit 741bb1ffa1 (parent b149cce5f8), authored by ccurme on 2025-06-05 11:20:06 -04:00, committed by GitHub.
4 changed files with 30 additions and 12 deletions

File 1 of 4

@@ -51,6 +51,7 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -445,10 +446,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessage]:
+    ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                self.invoke(input, config=config, stop=stop, **kwargs),
+            )
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
@@ -533,10 +537,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessage]:
+    ) -> AsyncIterator[BaseMessageChunk]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
             return
 
         config = ensure_config(config)

File 2 of 4

@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""
 
-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1
 
-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk
 
     assert accumulate_chunks == AIMessageChunk(
         content="",

File 3 of 4

@@ -163,10 +163,15 @@ async def test_astream_fallback_to_ainvoke() -> None:
 
     model = ModelWithGenerate()
     chunks = list(model.stream("anything"))
-    assert chunks == [_any_id_ai_message(content="hello")]
+    # BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
+    # When streaming is disabled, it returns Iterator[BaseMessage], so the type hint
+    # is not strictly correct.
+    # LangChain documents a pattern of adding BaseMessageChunks to accumulate a stream.
+    # This may be better done with `reduce(operator.add, chunks)`.
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]
 
     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_any_id_ai_message(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]
 
 
 async def test_astream_implementation_fallback_to_stream() -> None:

File 4 of 4

@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""
 
-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]
 
-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk
 
     assert accumulate_chunks == AIMessageChunk(
         id="a1",