Mirror of https://github.com/hwchase17/langchain.git
Synced 2025-06-19 05:13:46 +00:00
core[patch]: revert change to stream type hint (#31501)
https://github.com/langchain-ai/langchain/pull/31286 included an update to the return type of `BaseChatModel.(a)stream`, from `Iterator[BaseMessageChunk]` to `Iterator[BaseMessage]`. That change is correct in principle: when streaming is disabled, the stream methods return an iterator of `BaseMessage`, and a `BaseMessage` is not a `BaseMessageChunk` (though the reverse is true). However, LangChain documents a pattern throughout its docs of [summing BaseMessageChunks](https://python.langchain.com/docs/how_to/streaming/#llms-and-chat-models) to accumulate a chat model stream. This pattern is implemented in tests for most integration packages and appears in application code. As a result, https://github.com/langchain-ai/langchain/pull/31286 introduces mypy errors throughout the ecosystem (or, more accurately, it reveals that the pattern does not account for use of `.stream` when streaming is disabled). Here we revert just the change to the stream return type to unblock things. A proper fix should address the docs and integration packages (or, if we elect to force people to update their code, say so explicitly).
This commit is contained in: parent b149cce5f8 · commit 741bb1ffa1
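For context, here is a minimal sketch of the accumulation pattern the message refers to. It assumes the `langchain_openai` integration; the model name and prompt are placeholders, not taken from this commit:

```python
from typing import Optional

from langchain_core.messages import BaseMessageChunk
from langchain_openai import ChatOpenAI  # any BaseChatModel subclass works

model = ChatOpenAI(model="gpt-4o-mini")  # placeholder model choice

full: Optional[BaseMessageChunk] = None
for chunk in model.stream("Tell me a joke"):
    # BaseMessageChunk implements __add__, so chunks can be summed to
    # accumulate the stream. If .stream were typed to yield BaseMessage
    # (which does not support +), mypy would reject this line.
    full = chunk if full is None else full + chunk
```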
@@ -51,6 +51,7 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
+    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -445,10 +446,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessage]:
+    ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield self.invoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                self.invoke(input, config=config, stop=stop, **kwargs),
+            )
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
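In isolation, the fallback idiom above looks like this. The following is a standalone sketch with stand-in `Message`/`MessageChunk` classes, not LangChain's actual types:

```python
from typing import Iterator, cast

class Message:
    """Stand-in for BaseMessage."""

class MessageChunk(Message):
    """Stand-in for BaseMessageChunk; a chunk is-a message, not vice versa."""

def stream(streaming_enabled: bool) -> Iterator[MessageChunk]:
    if not streaming_enabled:
        # The non-streaming fallback produces a full Message. cast() keeps
        # the chunk-typed signature without any runtime conversion, which
        # is the compromise this revert makes.
        yield cast(MessageChunk, Message())
    else:
        yield MessageChunk()
```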
@@ -533,10 +537,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessage]:
+    ) -> AsyncIterator[BaseMessageChunk]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            yield cast(
+                "BaseMessageChunk",
+                await self.ainvoke(input, config=config, stop=stop, **kwargs),
+            )
             return

         config = ensure_config(config)
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         content="",
@@ -163,10 +163,15 @@ async def test_astream_fallback_to_ainvoke() -> None:

     model = ModelWithGenerate()
     chunks = list(model.stream("anything"))
-    assert chunks == [_any_id_ai_message(content="hello")]
+    # BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
+    # When streaming is disabled, it returns Iterator[BaseMessage], so the type hint
+    # is not strictly correct.
+    # LangChain documents a pattern of adding BaseMessageChunks to accumulate a stream.
+    # This may be better done with `reduce(operator.add, chunks)`.
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]

     chunks = [chunk async for chunk in model.astream("anything")]
-    assert chunks == [_any_id_ai_message(content="hello")]
+    assert chunks == [_any_id_ai_message(content="hello")]  # type: ignore[comparison-overlap]


 async def test_astream_implementation_fallback_to_stream() -> None:
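The comment in the hunk above suggests `reduce(operator.add, chunks)` as an alternative accumulation. A quick sketch of the equivalence, with made-up chunk contents:

```python
import operator
from functools import reduce

from langchain_core.messages import AIMessageChunk

chunks = [AIMessageChunk(content="Hello, "), AIMessageChunk(content="world!")]

# Both forms merge chunk contents via BaseMessageChunk.__add__.
assert reduce(operator.add, chunks) == chunks[0] + chunks[1]
```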
@@ -1,7 +1,5 @@
 """Tests for verifying that testing utility code works as expected."""

-import operator
-from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]

-    accumulate_chunks = reduce(operator.add, chunks)
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk

     assert accumulate_chunks == AIMessageChunk(
         id="a1",