mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 19:39:58 +00:00
core: In astream_events v2 propagate cancel/break to the inner astream call (#22865)
- previous behavior was for the inner astream to continue running with no interruption - also propagate break in core runnable methods
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Module that contains tests for runnable.astream_events API."""
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from itertools import cycle
|
||||
@@ -38,6 +39,7 @@ from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
chain,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.runnables.config import get_callback_manager_for_config
|
||||
@@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.utils.aiter import aclosing
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
|
||||
|
||||
@@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
|
||||
for event in events
|
||||
if event["event"] == "on_chat_model_stream"
|
||||
] == ["hello", " ", "world"]
|
||||
|
||||
|
||||
async def test_break_astream_events() -> None:
|
||||
class AwhileMaker:
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
async def __call__(self, input: Any) -> Any:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.sleep(0.5)
|
||||
return input
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.started = False
|
||||
self.cancelled = False
|
||||
|
||||
alittlewhile = AwhileMaker()
|
||||
awhile = AwhileMaker()
|
||||
anotherwhile = AwhileMaker()
|
||||
|
||||
outer_cancelled = False
|
||||
|
||||
@chain
|
||||
async def sequence(input: Any) -> Any:
|
||||
try:
|
||||
yield await alittlewhile(input)
|
||||
yield await awhile(input)
|
||||
yield await anotherwhile(input)
|
||||
except asyncio.CancelledError:
|
||||
nonlocal outer_cancelled
|
||||
outer_cancelled = True
|
||||
raise
|
||||
|
||||
# test interrupting astream_events v2
|
||||
|
||||
got_event = False
|
||||
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
||||
async with aclosing(
|
||||
sequence.astream_events({"value": 1}, thread2, version="v2")
|
||||
) as stream:
|
||||
async for chunk in stream:
|
||||
if chunk["event"] == "on_chain_stream":
|
||||
got_event = True
|
||||
assert chunk["data"]["chunk"] == {"value": 1}
|
||||
break
|
||||
|
||||
# did break
|
||||
assert got_event
|
||||
# did cancel outer chain
|
||||
assert outer_cancelled
|
||||
|
||||
# node "alittlewhile" starts, not cancelled
|
||||
assert alittlewhile.started is True
|
||||
assert alittlewhile.cancelled is False
|
||||
|
||||
# node "awhile" starts but is cancelled
|
||||
assert awhile.started is True
|
||||
assert awhile.cancelled is True
|
||||
|
||||
# node "anotherwhile" should never start
|
||||
assert anotherwhile.started is False
|
||||
|
||||
|
||||
async def test_cancel_astream_events() -> None:
|
||||
class AwhileMaker:
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
async def __call__(self, input: Any) -> Any:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.sleep(0.5)
|
||||
return input
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.started = False
|
||||
self.cancelled = False
|
||||
|
||||
alittlewhile = AwhileMaker()
|
||||
awhile = AwhileMaker()
|
||||
anotherwhile = AwhileMaker()
|
||||
|
||||
outer_cancelled = False
|
||||
|
||||
@chain
|
||||
async def sequence(input: Any) -> Any:
|
||||
try:
|
||||
yield await alittlewhile(input)
|
||||
yield await awhile(input)
|
||||
yield await anotherwhile(input)
|
||||
except asyncio.CancelledError:
|
||||
nonlocal outer_cancelled
|
||||
outer_cancelled = True
|
||||
raise
|
||||
|
||||
got_event = False
|
||||
|
||||
async def aconsume(stream: AsyncIterator[Any]) -> None:
|
||||
nonlocal got_event
|
||||
# here we don't need aclosing as cancelling the task is propagated
|
||||
# to the async generator being consumed
|
||||
async for chunk in stream:
|
||||
if chunk["event"] == "on_chain_stream":
|
||||
got_event = True
|
||||
assert chunk["data"]["chunk"] == {"value": 1}
|
||||
task.cancel()
|
||||
|
||||
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
||||
task = asyncio.create_task(
|
||||
aconsume(sequence.astream_events({"value": 1}, thread2, version="v2"))
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
# did break
|
||||
assert got_event
|
||||
# did cancel outer chain
|
||||
assert outer_cancelled
|
||||
|
||||
# node "alittlewhile" starts, not cancelled
|
||||
assert alittlewhile.started is True
|
||||
assert alittlewhile.cancelled is False
|
||||
|
||||
# node "awhile" starts but is cancelled
|
||||
assert awhile.started is True
|
||||
assert awhile.cancelled is True
|
||||
|
||||
# node "anotherwhile" should never start
|
||||
assert anotherwhile.started is False
|
||||
|
Reference in New Issue
Block a user