diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 21bf63a284f..8c36967b6e2 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -13,6 +13,7 @@ from operator import itemgetter from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, AsyncIterator, Awaitable, Callable, @@ -79,7 +80,7 @@ from langchain_core.runnables.utils import ( is_async_callable, is_async_generator, ) -from langchain_core.utils.aiter import atee, py_anext +from langchain_core.utils.aiter import aclosing, atee, py_anext from langchain_core.utils.iter import safetee if TYPE_CHECKING: @@ -1141,8 +1142,9 @@ class Runnable(Generic[Input, Output], ABC): 'Only versions "v1" and "v2" of the schema is currently supported.' ) - async for event in event_stream: - yield event + async with aclosing(event_stream): + async for event in event_stream: + yield event def transform( self, @@ -1948,7 +1950,7 @@ class Runnable(Generic[Input, Output], ABC): kwargs["run_manager"] = run_manager context = copy_context() context.run(_set_config_context, child_config) - iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] + iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] if stream_handler := next( ( @@ -1960,7 +1962,11 @@ class Runnable(Generic[Input, Output], ABC): None, ): # populates streamed_output in astream_log() output if needed - iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator) + iterator = stream_handler.tap_output_aiter( + run_manager.run_id, iterator_ + ) + else: + iterator = iterator_ try: while True: if accepts_context(asyncio.create_task): @@ -2001,6 +2007,9 @@ class Runnable(Generic[Input, Output], ABC): raise else: await run_manager.on_chain_end(final_output, inputs=final_input) + finally: + if hasattr(iterator_, "aclose"): + await iterator_.aclose() class RunnableSerializable(Serializable, Runnable[Input, Output]): @@ -3907,23 +3916,29 @@ class RunnableLambda(Runnable[Input, Output]): if is_async_generator(afunc): output: Optional[Output] = None - async for chunk in cast( - AsyncIterator[Output], - acall_func_with_variable_args( - cast(Callable, afunc), - input, - config, - run_manager, - **kwargs, - ), - ): - if output is None: - output = chunk - else: - try: - output = output + chunk # type: ignore[operator] - except TypeError: + async with aclosing( + cast( + AsyncGenerator[Any, Any], + acall_func_with_variable_args( + cast(Callable, afunc), + input, + config, + run_manager, + **kwargs, + ), + ) + ) as stream: + async for chunk in cast( + AsyncIterator[Output], + stream, + ): + if output is None: output = chunk + else: + try: + output = output + chunk # type: ignore[operator] + except TypeError: + output = chunk else: output = await acall_func_with_variable_args( cast(Callable, afunc), input, config, run_manager, **kwargs diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index c7839dca300..ad2d4eb6924 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -37,7 +37,7 @@ from langchain_core.runnables.utils import ( from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers.log_stream import LogEntry from langchain_core.tracers.memory_stream import _MemoryStream -from langchain_core.utils.aiter import py_anext +from langchain_core.utils.aiter import aclosing, py_anext if TYPE_CHECKING: from langchain_core.documents import Document @@ -903,11 +903,10 @@ async def _astream_events_implementation_v2( async def consume_astream() -> None: try: # if astream also calls tap_output_aiter this will be a no-op - async for _ in event_streamer.tap_output_aiter( - run_id, runnable.astream(input, config, **kwargs) - ): - # All the content will be picked up - pass + async with aclosing(runnable.astream(input, config, **kwargs)) as stream: + async for _ in event_streamer.tap_output_aiter(run_id, stream): + # All the content will be picked up + pass finally: await event_streamer.send_stream.aclose() @@ -942,7 +941,8 @@ async def _astream_events_implementation_v2( yield event finally: # Wait for the runnable to finish, if not cancelled (eg. by break) - try: - await task - except asyncio.CancelledError: - pass + if task.cancel(): + try: + await task + except asyncio.CancelledError: + pass diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index 837b5473849..f7a9db2d048 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -5,6 +5,8 @@ MIT License """ from collections import deque +from contextlib import AbstractAsyncContextManager +from types import TracebackType from typing import ( Any, AsyncContextManager, @@ -18,6 +20,7 @@ from typing import ( List, Optional, Tuple, + Type, TypeVar, Union, cast, @@ -207,3 +210,40 @@ class Tee(Generic[T]): atee = Tee + + +class aclosing(AbstractAsyncContextManager): + """Async context manager for safely finalizing an asynchronously cleaned-up + resource such as an async generator, calling its ``aclose()`` method. + + Code like this: + + async with aclosing(.fetch()) as agen: + + + is equivalent to this: + + agen = .fetch() + try: + + finally: + await agen.aclose() + + """ + + def __init__( + self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]] + ) -> None: + self.thing = thing + + async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]: + return self.thing + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + if hasattr(self.thing, "aclose"): + await self.thing.aclose() diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 81924b14bce..a708bb417ea 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -1,4 +1,5 @@ """Module that contains tests for runnable.astream_events API.""" +import asyncio import sys import uuid from itertools import cycle @@ -38,6 +39,7 @@ from langchain_core.runnables import ( RunnableConfig, RunnableGenerator, RunnableLambda, + chain, ensure_config, ) from langchain_core.runnables.config import get_callback_manager_for_config @@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import Input, Output from langchain_core.tools import tool +from langchain_core.utils.aiter import aclosing from tests.unit_tests.stubs import AnyStr @@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None: for event in events if event["event"] == "on_chat_model_stream" ] == ["hello", " ", "world"] + + +async def test_break_astream_events() -> None: + class AwhileMaker: + def __init__(self) -> None: + self.reset() + + async def __call__(self, input: Any) -> Any: + self.started = True + try: + await asyncio.sleep(0.5) + return input + except asyncio.CancelledError: + self.cancelled = True + raise + + def reset(self) -> None: + self.started = False + self.cancelled = False + + alittlewhile = AwhileMaker() + awhile = AwhileMaker() + anotherwhile = AwhileMaker() + + outer_cancelled = False + + @chain + async def sequence(input: Any) -> Any: + try: + yield await alittlewhile(input) + yield await awhile(input) + yield await anotherwhile(input) + except asyncio.CancelledError: + nonlocal outer_cancelled + outer_cancelled = True + raise + + # test interrupting astream_events v2 + + got_event = False + thread2: RunnableConfig = {"configurable": {"thread_id": 2}} + async with aclosing( + sequence.astream_events({"value": 1}, thread2, version="v2") + ) as stream: + async for chunk in stream: + if chunk["event"] == "on_chain_stream": + got_event = True + assert chunk["data"]["chunk"] == {"value": 1} + break + + # did break + assert got_event + # did cancel outer chain + assert outer_cancelled + + # node "alittlewhile" starts, not cancelled + assert alittlewhile.started is True + assert alittlewhile.cancelled is False + + # node "awhile" starts but is cancelled + assert awhile.started is True + assert awhile.cancelled is True + + # node "anotherwhile" should never start + assert anotherwhile.started is False + + +async def test_cancel_astream_events() -> None: + class AwhileMaker: + def __init__(self) -> None: + self.reset() + + async def __call__(self, input: Any) -> Any: + self.started = True + try: + await asyncio.sleep(0.5) + return input + except asyncio.CancelledError: + self.cancelled = True + raise + + def reset(self) -> None: + self.started = False + self.cancelled = False + + alittlewhile = AwhileMaker() + awhile = AwhileMaker() + anotherwhile = AwhileMaker() + + outer_cancelled = False + + @chain + async def sequence(input: Any) -> Any: + try: + yield await alittlewhile(input) + yield await awhile(input) + yield await anotherwhile(input) + except asyncio.CancelledError: + nonlocal outer_cancelled + outer_cancelled = True + raise + + got_event = False + + async def aconsume(stream: AsyncIterator[Any]) -> None: + nonlocal got_event + # here we don't need aclosing as cancelling the task is propagated + # to the async generator being consumed + async for chunk in stream: + if chunk["event"] == "on_chain_stream": + got_event = True + assert chunk["data"]["chunk"] == {"value": 1} + task.cancel() + + thread2: RunnableConfig = {"configurable": {"thread_id": 2}} + task = asyncio.create_task( + aconsume(sequence.astream_events({"value": 1}, thread2, version="v2")) + ) + + with pytest.raises(asyncio.CancelledError): + await task + + # did break + assert got_event + # did cancel outer chain + assert outer_cancelled + + # node "alittlewhile" starts, not cancelled + assert alittlewhile.started is True + assert alittlewhile.cancelled is False + + # node "awhile" starts but is cancelled + assert awhile.started is True + assert awhile.cancelled is True + + # node "anotherwhile" should never start + assert anotherwhile.started is False