core: In astream_events v2 propagate cancel/break to the inner astream call (#22865)

- previous behavior was for the inner astream to continue running with
no interruption
- also propagate break in core runnable methods
This commit is contained in:
Nuno Campos
2024-06-13 15:02:48 -07:00
committed by GitHub
parent a766815a99
commit bae82e966a
4 changed files with 226 additions and 31 deletions

View File

@@ -1,4 +1,5 @@
"""Module that contains tests for runnable.astream_events API."""
import asyncio
import sys
import uuid
from itertools import cycle
@@ -38,6 +39,7 @@ from langchain_core.runnables import (
RunnableConfig,
RunnableGenerator,
RunnableLambda,
chain,
ensure_config,
)
from langchain_core.runnables.config import get_callback_manager_for_config
@@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.stubs import AnyStr
@@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
for event in events
if event["event"] == "on_chat_model_stream"
] == ["hello", " ", "world"]
async def test_break_astream_events() -> None:
class AwhileMaker:
def __init__(self) -> None:
self.reset()
async def __call__(self, input: Any) -> Any:
self.started = True
try:
await asyncio.sleep(0.5)
return input
except asyncio.CancelledError:
self.cancelled = True
raise
def reset(self) -> None:
self.started = False
self.cancelled = False
alittlewhile = AwhileMaker()
awhile = AwhileMaker()
anotherwhile = AwhileMaker()
outer_cancelled = False
@chain
async def sequence(input: Any) -> Any:
try:
yield await alittlewhile(input)
yield await awhile(input)
yield await anotherwhile(input)
except asyncio.CancelledError:
nonlocal outer_cancelled
outer_cancelled = True
raise
# test interrupting astream_events v2
got_event = False
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
async with aclosing(
sequence.astream_events({"value": 1}, thread2, version="v2")
) as stream:
async for chunk in stream:
if chunk["event"] == "on_chain_stream":
got_event = True
assert chunk["data"]["chunk"] == {"value": 1}
break
# did break
assert got_event
# did cancel outer chain
assert outer_cancelled
# node "alittlewhile" starts, not cancelled
assert alittlewhile.started is True
assert alittlewhile.cancelled is False
# node "awhile" starts but is cancelled
assert awhile.started is True
assert awhile.cancelled is True
# node "anotherwhile" should never start
assert anotherwhile.started is False
async def test_cancel_astream_events() -> None:
class AwhileMaker:
def __init__(self) -> None:
self.reset()
async def __call__(self, input: Any) -> Any:
self.started = True
try:
await asyncio.sleep(0.5)
return input
except asyncio.CancelledError:
self.cancelled = True
raise
def reset(self) -> None:
self.started = False
self.cancelled = False
alittlewhile = AwhileMaker()
awhile = AwhileMaker()
anotherwhile = AwhileMaker()
outer_cancelled = False
@chain
async def sequence(input: Any) -> Any:
try:
yield await alittlewhile(input)
yield await awhile(input)
yield await anotherwhile(input)
except asyncio.CancelledError:
nonlocal outer_cancelled
outer_cancelled = True
raise
got_event = False
async def aconsume(stream: AsyncIterator[Any]) -> None:
nonlocal got_event
# here we don't need aclosing as cancelling the task is propagated
# to the async generator being consumed
async for chunk in stream:
if chunk["event"] == "on_chain_stream":
got_event = True
assert chunk["data"]["chunk"] == {"value": 1}
task.cancel()
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
task = asyncio.create_task(
aconsume(sequence.astream_events({"value": 1}, thread2, version="v2"))
)
with pytest.raises(asyncio.CancelledError):
await task
# did break
assert got_event
# did cancel outer chain
assert outer_cancelled
# node "alittlewhile" starts, not cancelled
assert alittlewhile.started is True
assert alittlewhile.cancelled is False
# node "awhile" starts but is cancelled
assert awhile.started is True
assert awhile.cancelled is True
# node "anotherwhile" should never start
assert anotherwhile.started is False