core[patch]: revert change to stream type hint (#31501)

https://github.com/langchain-ai/langchain/pull/31286 updated the return
type hint for `BaseChatModel.(a)stream` from
`Iterator[BaseMessageChunk]` to `Iterator[BaseMessage]`.

This change is correct: when streaming is disabled, the stream
methods return an iterator of `BaseMessage`, and the inheritance is such
that a `BaseMessage` is not a `BaseMessageChunk` (though the reverse is
true, since the chunk classes subclass their message counterparts).
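
For concreteness, the subclass relationship can be checked directly with the concrete `AIMessage`/`AIMessageChunk` pair (a minimal sketch, not part of this change):

```python
from langchain_core.messages import AIMessage, AIMessageChunk

chunk = AIMessageChunk(content="hi")
message = AIMessage(content="hi")

# Chunk classes subclass their non-chunk counterparts, not vice versa.
assert isinstance(chunk, AIMessage)
assert not isinstance(message, AIMessageChunk)
```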

However, LangChain includes a pattern throughout its docs of [summing
BaseMessageChunks](https://python.langchain.com/docs/how_to/streaming/#llms-and-chat-models)
to accumulate a chat model stream. This pattern is implemented in tests
for most integration packages and appears in application code. So
https://github.com/langchain-ai/langchain/pull/31286 introduces mypy
errors throughout the ecosystem (or, more accurately, it reveals that
this pattern does not account for use of the `.stream` method when
streaming is disabled).
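
The pattern in question looks roughly like this (a sketch; `accumulate` and its `model` argument are illustrative, not code from the docs). Under the `Iterator[BaseMessage]` hint it no longer type-checks, because `__add__` is defined on `BaseMessageChunk` rather than `BaseMessage`:

```python
from typing import Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessageChunk

def accumulate(model: BaseChatModel) -> Optional[BaseMessageChunk]:
    full: Optional[BaseMessageChunk] = None
    for chunk in model.stream("hello"):
        # Under Iterator[BaseMessage], both the assignment and the `+`
        # below fail mypy.
        full = chunk if full is None else full + chunk
    return full
```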

Here we revert just the change to the stream return type to unblock
things. A proper fix should address the docs and integration packages (or,
if we elect to force people to update their code, be explicit about that).
ccurme authored 2025-06-05 11:20:06 -04:00, committed by GitHub
commit 741bb1ffa1 (parent b149cce5f8)

4 changed files with 30 additions and 12 deletions


@@ -51,6 +51,7 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -445,10 +446,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessage]:
+    ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                self.invoke(input, config=config, stop=stop, **kwargs),
+            )
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
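
(Worth noting: `cast` only informs the type checker and performs no runtime conversion, so when streaming is disabled the fallback still yields the plain `BaseMessage` that `invoke` returned. A standalone sketch of that behavior:)

```python
from typing import cast

value: object = "hello"
text = cast(str, value)  # no runtime check or conversion happens here
assert text is value
```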
@@ -533,10 +537,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessage]:
+    ) -> AsyncIterator[BaseMessageChunk]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
             return

         config = ensure_config(config)
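
With the narrower hints restored, a consumer can treat both methods as yielding chunks regardless of whether the model actually streams. A hedged usage sketch (`collect_text` is illustrative; `str(chunk.content)` is used because `content` may be a string or a list of content blocks):

```python
from langchain_core.language_models import BaseChatModel

async def collect_text(model: BaseChatModel) -> str:
    parts: list[str] = []
    async for chunk in model.astream("hi"):  # typed AsyncIterator[BaseMessageChunk]
        parts.append(str(chunk.content))
    return "".join(parts)
```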


@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         content="",


@@ -163,10 +163,15 @@ async def test_astream_fallback_to_ainvoke() -> None:
     model = ModelWithGenerate()
     chunks = list(model.stream("anything"))
-    assert chunks == [_any_id_ai_message(content="hello")]
+    # BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
+    # When streaming is disabled, it returns Iterator[BaseMessage], so the type
+    # hint is not strictly correct.
+    # LangChain documents a pattern of adding BaseMessageChunks to accumulate a
+    # stream. This may be better done with `reduce(operator.add, chunks)`.
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]

     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_any_id_ai_message(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]


 async def test_astream_implementation_fallback_to_stream() -> None:
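
For context, `comparison-overlap` is the error mypy emits (under strict equality checking) when the two sides of `==` have types it considers disjoint, as `list[BaseMessageChunk]` and a list of `AIMessage`s are here. A minimal, self-contained reproduction with placeholder classes (assumed setup, not LangChain code):

```python
class Chunk: ...


class Message: ...


chunks: list[Chunk] = [Chunk()]
expected: list[Message] = [Message()]

# mypy --strict-equality: non-overlapping equality check
print(chunks == expected)
```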


@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         id="a1",