core: In astream_events v2 propagate cancel/break to the inner astream call (#22865)

- Previously, the inner astream call kept running without interruption
after the astream_events consumer cancelled or broke out of the stream.
- Also propagate break in the core runnable methods.
This commit is contained in:
Nuno Campos 2024-06-13 15:02:48 -07:00 committed by GitHub
parent a766815a99
commit bae82e966a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 226 additions and 31 deletions

View File

@ -13,6 +13,7 @@ from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncGenerator,
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Callable, Callable,
@ -79,7 +80,7 @@ from langchain_core.runnables.utils import (
is_async_callable, is_async_callable,
is_async_generator, is_async_generator,
) )
from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee from langchain_core.utils.iter import safetee
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1141,8 +1142,9 @@ class Runnable(Generic[Input, Output], ABC):
'Only versions "v1" and "v2" of the schema is currently supported.' 'Only versions "v1" and "v2" of the schema is currently supported.'
) )
async for event in event_stream: async with aclosing(event_stream):
yield event async for event in event_stream:
yield event
def transform( def transform(
self, self,
@ -1948,7 +1950,7 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["run_manager"] = run_manager kwargs["run_manager"] = run_manager
context = copy_context() context = copy_context()
context.run(_set_config_context, child_config) context.run(_set_config_context, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next( if stream_handler := next(
( (
@ -1960,7 +1962,11 @@ class Runnable(Generic[Input, Output], ABC):
None, None,
): ):
# populates streamed_output in astream_log() output if needed # populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator) iterator = stream_handler.tap_output_aiter(
run_manager.run_id, iterator_
)
else:
iterator = iterator_
try: try:
while True: while True:
if accepts_context(asyncio.create_task): if accepts_context(asyncio.create_task):
@ -2001,6 +2007,9 @@ class Runnable(Generic[Input, Output], ABC):
raise raise
else: else:
await run_manager.on_chain_end(final_output, inputs=final_input) await run_manager.on_chain_end(final_output, inputs=final_input)
finally:
if hasattr(iterator_, "aclose"):
await iterator_.aclose()
class RunnableSerializable(Serializable, Runnable[Input, Output]): class RunnableSerializable(Serializable, Runnable[Input, Output]):
@ -3907,23 +3916,29 @@ class RunnableLambda(Runnable[Input, Output]):
if is_async_generator(afunc): if is_async_generator(afunc):
output: Optional[Output] = None output: Optional[Output] = None
async for chunk in cast( async with aclosing(
AsyncIterator[Output], cast(
acall_func_with_variable_args( AsyncGenerator[Any, Any],
cast(Callable, afunc), acall_func_with_variable_args(
input, cast(Callable, afunc),
config, input,
run_manager, config,
**kwargs, run_manager,
), **kwargs,
): ),
if output is None: )
output = chunk ) as stream:
else: async for chunk in cast(
try: AsyncIterator[Output],
output = output + chunk # type: ignore[operator] stream,
except TypeError: ):
if output is None:
output = chunk output = chunk
else:
try:
output = output + chunk # type: ignore[operator]
except TypeError:
output = chunk
else: else:
output = await acall_func_with_variable_args( output = await acall_func_with_variable_args(
cast(Callable, afunc), input, config, run_manager, **kwargs cast(Callable, afunc), input, config, run_manager, **kwargs

View File

@ -37,7 +37,7 @@ from langchain_core.runnables.utils import (
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import LogEntry from langchain_core.tracers.log_stream import LogEntry
from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import py_anext from langchain_core.utils.aiter import aclosing, py_anext
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.documents import Document from langchain_core.documents import Document
@ -903,11 +903,10 @@ async def _astream_events_implementation_v2(
async def consume_astream() -> None: async def consume_astream() -> None:
try: try:
# if astream also calls tap_output_aiter this will be a no-op # if astream also calls tap_output_aiter this will be a no-op
async for _ in event_streamer.tap_output_aiter( async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
run_id, runnable.astream(input, config, **kwargs) async for _ in event_streamer.tap_output_aiter(run_id, stream):
): # All the content will be picked up
# All the content will be picked up pass
pass
finally: finally:
await event_streamer.send_stream.aclose() await event_streamer.send_stream.aclose()
@ -942,7 +941,8 @@ async def _astream_events_implementation_v2(
yield event yield event
finally: finally:
# Wait for the runnable to finish, if not cancelled (eg. by break) # Wait for the runnable to finish, if not cancelled (eg. by break)
try: if task.cancel():
await task try:
except asyncio.CancelledError: await task
pass except asyncio.CancelledError:
pass

View File

@ -5,6 +5,8 @@ MIT License
""" """
from collections import deque from collections import deque
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import ( from typing import (
Any, Any,
AsyncContextManager, AsyncContextManager,
@ -18,6 +20,7 @@ from typing import (
List, List,
Optional, Optional,
Tuple, Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
cast, cast,
@ -207,3 +210,40 @@ class Tee(Generic[T]):
atee = Tee atee = Tee
class aclosing(AbstractAsyncContextManager):
    """Async context manager that finalizes an async iterator on exit.

    Entering the context hands back the wrapped object unchanged. Leaving
    the context, whether normally, via ``break``/``return`` or because an
    exception is propagating, awaits the object's ``aclose()`` method when
    it has one. Objects without ``aclose`` are left untouched.

    Usage::

        async with aclosing(<module>.fetch(<arguments>)) as agen:
            <block>

    which behaves like::

        agen = <module>.fetch(<arguments>)
        try:
            <block>
        finally:
            await agen.aclose()

    Exceptions raised inside the block are never suppressed.
    """

    def __init__(
        self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]
    ) -> None:
        # Keep a reference so __aexit__ can finalize the same object.
        self.thing = thing

    async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]:
        """Return the wrapped object as the ``as`` target."""
        return self.thing

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Await ``aclose()`` on the wrapped object if it provides one."""
        closer = getattr(self.thing, "aclose", None)
        if closer is None:
            # Plain async iterators have no aclose(); nothing to finalize.
            return
        await closer()

View File

@ -1,4 +1,5 @@
"""Module that contains tests for runnable.astream_events API.""" """Module that contains tests for runnable.astream_events API."""
import asyncio
import sys import sys
import uuid import uuid
from itertools import cycle from itertools import cycle
@ -38,6 +39,7 @@ from langchain_core.runnables import (
RunnableConfig, RunnableConfig,
RunnableGenerator, RunnableGenerator,
RunnableLambda, RunnableLambda,
chain,
ensure_config, ensure_config,
) )
from langchain_core.runnables.config import get_callback_manager_for_config from langchain_core.runnables.config import get_callback_manager_for_config
@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.stubs import AnyStr from tests.unit_tests.stubs import AnyStr
@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
for event in events for event in events
if event["event"] == "on_chat_model_stream" if event["event"] == "on_chat_model_stream"
] == ["hello", " ", "world"] ] == ["hello", " ", "world"]
async def test_break_astream_events() -> None:
    """Breaking out of ``astream_events`` (v2) must cancel the inner chain.

    Three sleeping steps run in sequence inside a ``@chain`` generator. The
    consumer breaks on the first ``on_chain_stream`` event; once the stream
    is closed the step that was in flight must have observed cancellation
    and the step after it must never have started.
    """

    class AwhileMaker:
        """Awaitable step that records whether it started / was cancelled."""

        def __init__(self) -> None:
            self.reset()

        async def __call__(self, input: Any) -> Any:
            self.started = True
            try:
                await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                # Record the cancellation, then let it keep propagating.
                self.cancelled = True
                raise
            return input

        def reset(self) -> None:
            self.started = False
            self.cancelled = False

    first_step = AwhileMaker()
    second_step = AwhileMaker()
    third_step = AwhileMaker()

    outer_cancelled = False

    @chain
    async def sequence(input: Any) -> Any:
        nonlocal outer_cancelled
        try:
            yield await first_step(input)
            yield await second_step(input)
            yield await third_step(input)
        except asyncio.CancelledError:
            outer_cancelled = True
            raise

    # Interrupt astream_events v2 by breaking out of the loop; aclosing
    # makes sure the event stream is closed as soon as the block exits.
    got_event = False
    thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
    event_stream = sequence.astream_events({"value": 1}, thread2, version="v2")
    async with aclosing(event_stream) as stream:
        async for chunk in stream:
            if chunk["event"] != "on_chain_stream":
                continue
            got_event = True
            assert chunk["data"]["chunk"] == {"value": 1}
            break

    # The consumer saw a stream event and broke out.
    assert got_event
    # The outer chain observed the cancellation.
    assert outer_cancelled

    # First step ran to completion without being cancelled.
    assert (first_step.started, first_step.cancelled) == (True, False)
    # Second step was in flight and got cancelled.
    assert (second_step.started, second_step.cancelled) == (True, True)
    # Third step must never have started.
    assert third_step.started is False
async def test_cancel_astream_events() -> None:
    """Cancelling the consuming task must cancel the inner chain.

    Three sleeping steps run in sequence inside a ``@chain`` generator. The
    consumer task cancels itself on the first ``on_chain_stream`` event;
    afterwards the step that was in flight must have observed cancellation
    and the step after it must never have started.
    """

    class AwhileMaker:
        """Awaitable step that records whether it started / was cancelled."""

        def __init__(self) -> None:
            self.reset()

        async def __call__(self, input: Any) -> Any:
            self.started = True
            try:
                await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                # Record the cancellation, then let it keep propagating.
                self.cancelled = True
                raise
            return input

        def reset(self) -> None:
            self.started = False
            self.cancelled = False

    first_step = AwhileMaker()
    second_step = AwhileMaker()
    third_step = AwhileMaker()

    outer_cancelled = False

    @chain
    async def sequence(input: Any) -> Any:
        nonlocal outer_cancelled
        try:
            yield await first_step(input)
            yield await second_step(input)
            yield await third_step(input)
        except asyncio.CancelledError:
            outer_cancelled = True
            raise

    got_event = False

    async def aconsume(stream: AsyncIterator[Any]) -> None:
        nonlocal got_event
        # No aclosing needed here: cancelling the task is propagated to the
        # async generator being consumed.
        async for chunk in stream:
            if chunk["event"] != "on_chain_stream":
                continue
            got_event = True
            assert chunk["data"]["chunk"] == {"value": 1}
            # ``consumer`` is bound below, before this coroutine first runs.
            consumer.cancel()

    thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
    event_stream = sequence.astream_events({"value": 1}, thread2, version="v2")
    consumer = asyncio.create_task(aconsume(event_stream))

    with pytest.raises(asyncio.CancelledError):
        await consumer

    # The consumer saw a stream event before it was cancelled.
    assert got_event
    # The outer chain observed the cancellation.
    assert outer_cancelled

    # First step ran to completion without being cancelled.
    assert (first_step.started, first_step.cancelled) == (True, False)
    # Second step was in flight and got cancelled.
    assert (second_step.started, second_step.cancelled) == (True, True)
    # Third step must never have started.
    assert third_step.started is False