mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-14 19:05:21 +00:00
feat(core): add stream_events version='v3' overload
Adds @overload signatures to `Runnable.astream_events` and introduces a new `Runnable.stream_events` sync method, both accepting `version='v3'`. The base-class implementation raises `NotImplementedError` with a message directing callers to use a subclass that implements the v3 streaming protocol (BaseChatModel, CompiledGraph). v1/v2 behavior is unchanged.
This commit is contained in:
@@ -1314,7 +1314,8 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
):
|
||||
yield item
|
||||
|
||||
async def astream_events(
|
||||
@overload
|
||||
def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
@@ -1327,6 +1328,31 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]: ...
|
||||
|
||||
@overload
|
||||
def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
version: Literal["v3"],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Any]: ...
|
||||
|
||||
async def astream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
version: Literal["v1", "v2", "v3"] = "v2",
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Generate a stream of events.
|
||||
|
||||
@@ -1521,10 +1547,19 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
An async stream of `StreamEvent`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the version is not `'v1'` or `'v2'`.
|
||||
NotImplementedError: If the version is not `'v1'`, `'v2'`, or `'v3'`, or
|
||||
if `'v3'` is requested on a `Runnable` that does not implement the v3
|
||||
streaming protocol.
|
||||
|
||||
""" # noqa: E501
|
||||
if version == "v2":
|
||||
if version == "v3":
|
||||
msg = (
|
||||
"astream_events(version='v3') is only supported on Runnable subclasses "
|
||||
"that implement the v3 streaming protocol (BaseChatModel, CompiledGraph). "
|
||||
f"Got: {type(self).__name__}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
elif version == "v2":
|
||||
event_stream = _astream_events_implementation_v2(
|
||||
self,
|
||||
input,
|
||||
@@ -1553,13 +1588,100 @@ class Runnable(ABC, Generic[Input, Output]):
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
msg = 'Only versions "v1" and "v2" of the schema is currently supported.'
|
||||
msg = f"Unsupported version: {version!r}. Expected 'v1', 'v2', or 'v3'."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
async with aclosing(event_stream):
|
||||
async for event in event_stream:
|
||||
yield event
|
||||
|
||||
@overload
|
||||
def stream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
version: Literal["v1", "v2"] = "v2",
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[StreamEvent]: ...
|
||||
|
||||
@overload
|
||||
def stream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
version: Literal["v3"],
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Any]: ...
|
||||
|
||||
def stream_events(
|
||||
self,
|
||||
input: Any,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
version: Literal["v1", "v2", "v3"] = "v2",
|
||||
include_names: Sequence[str] | None = None,
|
||||
include_types: Sequence[str] | None = None,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_names: Sequence[str] | None = None,
|
||||
exclude_types: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[StreamEvent]:
|
||||
"""Generate a stream of events synchronously.
|
||||
|
||||
Synchronous counterpart to `astream_events`. For `version='v3'`, subclasses
|
||||
that implement the v3 streaming protocol (`BaseChatModel`, `CompiledGraph`)
|
||||
override this method. All other versions and base-class calls raise
|
||||
`NotImplementedError`.
|
||||
|
||||
Args:
|
||||
input: The input to the `Runnable`.
|
||||
config: The config to use for the `Runnable`.
|
||||
version: The version of the schema to use. `'v3'` requires a subclass
|
||||
that implements the v3 streaming protocol. `'v1'` and `'v2'` are not
|
||||
supported on the sync path.
|
||||
include_names: Only include events from `Runnable` objects with matching
|
||||
names.
|
||||
include_types: Only include events from `Runnable` objects with matching
|
||||
types.
|
||||
include_tags: Only include events from `Runnable` objects with matching
|
||||
tags.
|
||||
exclude_names: Exclude events from `Runnable` objects with matching names.
|
||||
exclude_types: Exclude events from `Runnable` objects with matching types.
|
||||
exclude_tags: Exclude events from `Runnable` objects with matching tags.
|
||||
**kwargs: Additional keyword arguments to pass to the `Runnable`.
|
||||
|
||||
Yields:
|
||||
A stream of events.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Always. Subclasses override this method for supported
|
||||
versions.
|
||||
|
||||
"""
|
||||
if version == "v3":
|
||||
msg = (
|
||||
"stream_events(version='v3') is only supported on Runnable subclasses "
|
||||
"that implement the v3 streaming protocol (BaseChatModel, CompiledGraph). "
|
||||
f"Got: {type(self).__name__}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
else:
|
||||
msg = (
|
||||
f"stream_events(version={version!r}) is not supported. "
|
||||
"Use astream_events() for v1/v2, or stream_events(version='v3') "
|
||||
"on a supported subclass."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def transform(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Tests for the v3 dispatch path on Runnable.astream_events / stream_events."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
|
||||
|
||||
def _double(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
|
||||
def test_v3_on_plain_runnable_raises_not_implemented_sync() -> None:
|
||||
runnable = RunnableLambda(_double)
|
||||
with pytest.raises(NotImplementedError, match="v3"):
|
||||
runnable.stream_events(2, version="v3")
|
||||
|
||||
|
||||
async def test_v3_on_plain_runnable_raises_not_implemented_async() -> None:
|
||||
runnable = RunnableLambda(_double)
|
||||
with pytest.raises(NotImplementedError, match="v3"):
|
||||
async for _ in runnable.astream_events(2, version="v3"):
|
||||
pass
|
||||
Reference in New Issue
Block a user