mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
core: In astream_events v2 propagate cancel/break to the inner astream call (#22865)
- previous behavior was for the inner astream to continue running with no interruption - also propagate break in core runnable methods
This commit is contained in:
parent
a766815a99
commit
bae82e966a
@ -13,6 +13,7 @@ from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
@ -79,7 +80,7 @@ from langchain_core.runnables.utils import (
|
||||
is_async_callable,
|
||||
is_async_generator,
|
||||
)
|
||||
from langchain_core.utils.aiter import atee, py_anext
|
||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||
from langchain_core.utils.iter import safetee
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -1141,8 +1142,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
'Only versions "v1" and "v2" of the schema is currently supported.'
|
||||
)
|
||||
|
||||
async for event in event_stream:
|
||||
yield event
|
||||
async with aclosing(event_stream):
|
||||
async for event in event_stream:
|
||||
yield event
|
||||
|
||||
def transform(
|
||||
self,
|
||||
@ -1948,7 +1950,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
kwargs["run_manager"] = run_manager
|
||||
context = copy_context()
|
||||
context.run(_set_config_context, child_config)
|
||||
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
if stream_handler := next(
|
||||
(
|
||||
@ -1960,7 +1962,11 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
None,
|
||||
):
|
||||
# populates streamed_output in astream_log() output if needed
|
||||
iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator)
|
||||
iterator = stream_handler.tap_output_aiter(
|
||||
run_manager.run_id, iterator_
|
||||
)
|
||||
else:
|
||||
iterator = iterator_
|
||||
try:
|
||||
while True:
|
||||
if accepts_context(asyncio.create_task):
|
||||
@ -2001,6 +2007,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
raise
|
||||
else:
|
||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||
finally:
|
||||
if hasattr(iterator_, "aclose"):
|
||||
await iterator_.aclose()
|
||||
|
||||
|
||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
@ -3907,23 +3916,29 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
if is_async_generator(afunc):
|
||||
output: Optional[Output] = None
|
||||
async for chunk in cast(
|
||||
AsyncIterator[Output],
|
||||
acall_func_with_variable_args(
|
||||
cast(Callable, afunc),
|
||||
input,
|
||||
config,
|
||||
run_manager,
|
||||
**kwargs,
|
||||
),
|
||||
):
|
||||
if output is None:
|
||||
output = chunk
|
||||
else:
|
||||
try:
|
||||
output = output + chunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
async with aclosing(
|
||||
cast(
|
||||
AsyncGenerator[Any, Any],
|
||||
acall_func_with_variable_args(
|
||||
cast(Callable, afunc),
|
||||
input,
|
||||
config,
|
||||
run_manager,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
) as stream:
|
||||
async for chunk in cast(
|
||||
AsyncIterator[Output],
|
||||
stream,
|
||||
):
|
||||
if output is None:
|
||||
output = chunk
|
||||
else:
|
||||
try:
|
||||
output = output + chunk # type: ignore[operator]
|
||||
except TypeError:
|
||||
output = chunk
|
||||
else:
|
||||
output = await acall_func_with_variable_args(
|
||||
cast(Callable, afunc), input, config, run_manager, **kwargs
|
||||
|
@ -37,7 +37,7 @@ from langchain_core.runnables.utils import (
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.tracers.log_stream import LogEntry
|
||||
from langchain_core.tracers.memory_stream import _MemoryStream
|
||||
from langchain_core.utils.aiter import py_anext
|
||||
from langchain_core.utils.aiter import aclosing, py_anext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
@ -903,11 +903,10 @@ async def _astream_events_implementation_v2(
|
||||
async def consume_astream() -> None:
|
||||
try:
|
||||
# if astream also calls tap_output_aiter this will be a no-op
|
||||
async for _ in event_streamer.tap_output_aiter(
|
||||
run_id, runnable.astream(input, config, **kwargs)
|
||||
):
|
||||
# All the content will be picked up
|
||||
pass
|
||||
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
|
||||
async for _ in event_streamer.tap_output_aiter(run_id, stream):
|
||||
# All the content will be picked up
|
||||
pass
|
||||
finally:
|
||||
await event_streamer.send_stream.aclose()
|
||||
|
||||
@ -942,7 +941,8 @@ async def _astream_events_implementation_v2(
|
||||
yield event
|
||||
finally:
|
||||
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if task.cancel():
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
@ -5,6 +5,8 @@ MIT License
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
@ -18,6 +20,7 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -207,3 +210,40 @@ class Tee(Generic[T]):
|
||||
|
||||
|
||||
atee = Tee
|
||||
|
||||
|
||||
class aclosing(AbstractAsyncContextManager):
|
||||
"""Async context manager for safely finalizing an asynchronously cleaned-up
|
||||
resource such as an async generator, calling its ``aclose()`` method.
|
||||
|
||||
Code like this:
|
||||
|
||||
async with aclosing(<module>.fetch(<arguments>)) as agen:
|
||||
<block>
|
||||
|
||||
is equivalent to this:
|
||||
|
||||
agen = <module>.fetch(<arguments>)
|
||||
try:
|
||||
<block>
|
||||
finally:
|
||||
await agen.aclose()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]
|
||||
) -> None:
|
||||
self.thing = thing
|
||||
|
||||
async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]:
|
||||
return self.thing
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
if hasattr(self.thing, "aclose"):
|
||||
await self.thing.aclose()
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Module that contains tests for runnable.astream_events API."""
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from itertools import cycle
|
||||
@ -38,6 +39,7 @@ from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
chain,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.runnables.config import get_callback_manager_for_config
|
||||
@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.utils.aiter import aclosing
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
|
||||
|
||||
@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
|
||||
for event in events
|
||||
if event["event"] == "on_chat_model_stream"
|
||||
] == ["hello", " ", "world"]
|
||||
|
||||
|
||||
async def test_break_astream_events() -> None:
|
||||
class AwhileMaker:
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
async def __call__(self, input: Any) -> Any:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.sleep(0.5)
|
||||
return input
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.started = False
|
||||
self.cancelled = False
|
||||
|
||||
alittlewhile = AwhileMaker()
|
||||
awhile = AwhileMaker()
|
||||
anotherwhile = AwhileMaker()
|
||||
|
||||
outer_cancelled = False
|
||||
|
||||
@chain
|
||||
async def sequence(input: Any) -> Any:
|
||||
try:
|
||||
yield await alittlewhile(input)
|
||||
yield await awhile(input)
|
||||
yield await anotherwhile(input)
|
||||
except asyncio.CancelledError:
|
||||
nonlocal outer_cancelled
|
||||
outer_cancelled = True
|
||||
raise
|
||||
|
||||
# test interrupting astream_events v2
|
||||
|
||||
got_event = False
|
||||
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
||||
async with aclosing(
|
||||
sequence.astream_events({"value": 1}, thread2, version="v2")
|
||||
) as stream:
|
||||
async for chunk in stream:
|
||||
if chunk["event"] == "on_chain_stream":
|
||||
got_event = True
|
||||
assert chunk["data"]["chunk"] == {"value": 1}
|
||||
break
|
||||
|
||||
# did break
|
||||
assert got_event
|
||||
# did cancel outer chain
|
||||
assert outer_cancelled
|
||||
|
||||
# node "alittlewhile" starts, not cancelled
|
||||
assert alittlewhile.started is True
|
||||
assert alittlewhile.cancelled is False
|
||||
|
||||
# node "awhile" starts but is cancelled
|
||||
assert awhile.started is True
|
||||
assert awhile.cancelled is True
|
||||
|
||||
# node "anotherwhile" should never start
|
||||
assert anotherwhile.started is False
|
||||
|
||||
|
||||
async def test_cancel_astream_events() -> None:
|
||||
class AwhileMaker:
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
async def __call__(self, input: Any) -> Any:
|
||||
self.started = True
|
||||
try:
|
||||
await asyncio.sleep(0.5)
|
||||
return input
|
||||
except asyncio.CancelledError:
|
||||
self.cancelled = True
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.started = False
|
||||
self.cancelled = False
|
||||
|
||||
alittlewhile = AwhileMaker()
|
||||
awhile = AwhileMaker()
|
||||
anotherwhile = AwhileMaker()
|
||||
|
||||
outer_cancelled = False
|
||||
|
||||
@chain
|
||||
async def sequence(input: Any) -> Any:
|
||||
try:
|
||||
yield await alittlewhile(input)
|
||||
yield await awhile(input)
|
||||
yield await anotherwhile(input)
|
||||
except asyncio.CancelledError:
|
||||
nonlocal outer_cancelled
|
||||
outer_cancelled = True
|
||||
raise
|
||||
|
||||
got_event = False
|
||||
|
||||
async def aconsume(stream: AsyncIterator[Any]) -> None:
|
||||
nonlocal got_event
|
||||
# here we don't need aclosing as cancelling the task is propagated
|
||||
# to the async generator being consumed
|
||||
async for chunk in stream:
|
||||
if chunk["event"] == "on_chain_stream":
|
||||
got_event = True
|
||||
assert chunk["data"]["chunk"] == {"value": 1}
|
||||
task.cancel()
|
||||
|
||||
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
||||
task = asyncio.create_task(
|
||||
aconsume(sequence.astream_events({"value": 1}, thread2, version="v2"))
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
# did break
|
||||
assert got_event
|
||||
# did cancel outer chain
|
||||
assert outer_cancelled
|
||||
|
||||
# node "alittlewhile" starts, not cancelled
|
||||
assert alittlewhile.started is True
|
||||
assert alittlewhile.cancelled is False
|
||||
|
||||
# node "awhile" starts but is cancelled
|
||||
assert awhile.started is True
|
||||
assert awhile.cancelled is True
|
||||
|
||||
# node "anotherwhile" should never start
|
||||
assert anotherwhile.started is False
|
||||
|
Loading…
Reference in New Issue
Block a user