mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
core[patch]: revert change to stream type hint (#31501)
https://github.com/langchain-ai/langchain/pull/31286 included an update to the return type for `BaseChatModel.(a)stream`, from `Iterator[BaseMessageChunk]` to `Iterator[BaseMessage]`. This change is correct, because when streaming is disabled, the stream methods return an iterator of `BaseMessage`, and the inheritance is such that an `BaseMessage` is not a `BaseMessageChunk` (but the reverse is true). However, LangChain includes a pattern throughout its docs of [summing BaseMessageChunks](https://python.langchain.com/docs/how_to/streaming/#llms-and-chat-models) to accumulate a chat model stream. This pattern is implemented in tests for most integration packages and appears in application code. So https://github.com/langchain-ai/langchain/pull/31286 introduces mypy errors throughout the ecosystem (or maybe more accurately, it reveals that this pattern does not account for use of the `.stream` method when streaming is disabled). Here we revert just the change to the stream return type to unblock things. A fix for this should address docs + integration packages (or if we elect to just force people to update code, be explicit about that).
This commit is contained in:
parent
b149cce5f8
commit
741bb1ffa1
@ -51,6 +51,7 @@ from langchain_core.messages import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
|
BaseMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
convert_to_messages,
|
convert_to_messages,
|
||||||
convert_to_openai_image_block,
|
convert_to_openai_image_block,
|
||||||
@ -445,10 +446,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
*,
|
*,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[BaseMessage]:
|
) -> Iterator[BaseMessageChunk]:
|
||||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||||
# model doesn't implement streaming, so use default implementation
|
# model doesn't implement streaming, so use default implementation
|
||||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
yield cast(
|
||||||
|
"BaseMessageChunk",
|
||||||
|
self.invoke(input, config=config, stop=stop, **kwargs),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
@ -533,10 +537,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
*,
|
*,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[BaseMessage]:
|
) -> AsyncIterator[BaseMessageChunk]:
|
||||||
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
||||||
# No async or sync stream is implemented, so fall back to ainvoke
|
# No async or sync stream is implemented, so fall back to ainvoke
|
||||||
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
|
yield cast(
|
||||||
|
"BaseMessageChunk",
|
||||||
|
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
"""Tests for verifying that testing utility code works as expected."""
|
"""Tests for verifying that testing utility code works as expected."""
|
||||||
|
|
||||||
import operator
|
|
||||||
from functools import reduce
|
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -117,7 +115,12 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
accumulate_chunks = reduce(operator.add, chunks)
|
accumulate_chunks = None
|
||||||
|
for chunk in chunks:
|
||||||
|
if accumulate_chunks is None:
|
||||||
|
accumulate_chunks = chunk
|
||||||
|
else:
|
||||||
|
accumulate_chunks += chunk
|
||||||
|
|
||||||
assert accumulate_chunks == AIMessageChunk(
|
assert accumulate_chunks == AIMessageChunk(
|
||||||
content="",
|
content="",
|
||||||
|
@ -163,10 +163,15 @@ async def test_astream_fallback_to_ainvoke() -> None:
|
|||||||
|
|
||||||
model = ModelWithGenerate()
|
model = ModelWithGenerate()
|
||||||
chunks = list(model.stream("anything"))
|
chunks = list(model.stream("anything"))
|
||||||
assert chunks == [_any_id_ai_message(content="hello")]
|
# BaseChatModel.stream is typed to return Iterator[BaseMessageChunk].
|
||||||
|
# When streaming is disabled, it returns Iterator[BaseMessage], so the type hint
|
||||||
|
# is not strictly correct.
|
||||||
|
# LangChain documents a pattern of adding BaseMessageChunks to accumulate a stream.
|
||||||
|
# This may be better done with `reduce(operator.add, chunks)`.
|
||||||
|
assert chunks == [_any_id_ai_message(content="hello")] # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
chunks = [chunk async for chunk in model.astream("anything")]
|
chunks = [chunk async for chunk in model.astream("anything")]
|
||||||
assert chunks == [_any_id_ai_message(content="hello")]
|
assert chunks == [_any_id_ai_message(content="hello")] # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
|
|
||||||
async def test_astream_implementation_fallback_to_stream() -> None:
|
async def test_astream_implementation_fallback_to_stream() -> None:
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
"""Tests for verifying that testing utility code works as expected."""
|
"""Tests for verifying that testing utility code works as expected."""
|
||||||
|
|
||||||
import operator
|
|
||||||
from functools import reduce
|
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -109,7 +107,12 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
accumulate_chunks = reduce(operator.add, chunks)
|
accumulate_chunks = None
|
||||||
|
for chunk in chunks:
|
||||||
|
if accumulate_chunks is None:
|
||||||
|
accumulate_chunks = chunk
|
||||||
|
else:
|
||||||
|
accumulate_chunks += chunk
|
||||||
|
|
||||||
assert accumulate_chunks == AIMessageChunk(
|
assert accumulate_chunks == AIMessageChunk(
|
||||||
id="a1",
|
id="a1",
|
||||||
|
Loading…
Reference in New Issue
Block a user