mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 10:42:55 +00:00
core: In astream_events v2 propagate cancel/break to the inner astream call (#22865)
- previous behavior was for the inner astream to continue running with no interruption - also propagate break in core runnable methods
This commit is contained in:
parent
a766815a99
commit
bae82e966a
@ -13,6 +13,7 @@ from operator import itemgetter
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
@ -79,7 +80,7 @@ from langchain_core.runnables.utils import (
|
|||||||
is_async_callable,
|
is_async_callable,
|
||||||
is_async_generator,
|
is_async_generator,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.aiter import atee, py_anext
|
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||||
from langchain_core.utils.iter import safetee
|
from langchain_core.utils.iter import safetee
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -1141,8 +1142,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
'Only versions "v1" and "v2" of the schema is currently supported.'
|
'Only versions "v1" and "v2" of the schema is currently supported.'
|
||||||
)
|
)
|
||||||
|
|
||||||
async for event in event_stream:
|
async with aclosing(event_stream):
|
||||||
yield event
|
async for event in event_stream:
|
||||||
|
yield event
|
||||||
|
|
||||||
def transform(
|
def transform(
|
||||||
self,
|
self,
|
||||||
@ -1948,7 +1950,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
kwargs["run_manager"] = run_manager
|
kwargs["run_manager"] = run_manager
|
||||||
context = copy_context()
|
context = copy_context()
|
||||||
context.run(_set_config_context, child_config)
|
context.run(_set_config_context, child_config)
|
||||||
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
|
||||||
|
|
||||||
if stream_handler := next(
|
if stream_handler := next(
|
||||||
(
|
(
|
||||||
@ -1960,7 +1962,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
None,
|
None,
|
||||||
):
|
):
|
||||||
# populates streamed_output in astream_log() output if needed
|
# populates streamed_output in astream_log() output if needed
|
||||||
iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator)
|
iterator = stream_handler.tap_output_aiter(
|
||||||
|
run_manager.run_id, iterator_
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
iterator = iterator_
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if accepts_context(asyncio.create_task):
|
if accepts_context(asyncio.create_task):
|
||||||
@ -2001,6 +2007,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||||
|
finally:
|
||||||
|
if hasattr(iterator_, "aclose"):
|
||||||
|
await iterator_.aclose()
|
||||||
|
|
||||||
|
|
||||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||||
@ -3907,23 +3916,29 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
|
|
||||||
if is_async_generator(afunc):
|
if is_async_generator(afunc):
|
||||||
output: Optional[Output] = None
|
output: Optional[Output] = None
|
||||||
async for chunk in cast(
|
async with aclosing(
|
||||||
AsyncIterator[Output],
|
cast(
|
||||||
acall_func_with_variable_args(
|
AsyncGenerator[Any, Any],
|
||||||
cast(Callable, afunc),
|
acall_func_with_variable_args(
|
||||||
input,
|
cast(Callable, afunc),
|
||||||
config,
|
input,
|
||||||
run_manager,
|
config,
|
||||||
**kwargs,
|
run_manager,
|
||||||
),
|
**kwargs,
|
||||||
):
|
),
|
||||||
if output is None:
|
)
|
||||||
output = chunk
|
) as stream:
|
||||||
else:
|
async for chunk in cast(
|
||||||
try:
|
AsyncIterator[Output],
|
||||||
output = output + chunk # type: ignore[operator]
|
stream,
|
||||||
except TypeError:
|
):
|
||||||
|
if output is None:
|
||||||
output = chunk
|
output = chunk
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
output = output + chunk # type: ignore[operator]
|
||||||
|
except TypeError:
|
||||||
|
output = chunk
|
||||||
else:
|
else:
|
||||||
output = await acall_func_with_variable_args(
|
output = await acall_func_with_variable_args(
|
||||||
cast(Callable, afunc), input, config, run_manager, **kwargs
|
cast(Callable, afunc), input, config, run_manager, **kwargs
|
||||||
|
@ -37,7 +37,7 @@ from langchain_core.runnables.utils import (
|
|||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
from langchain_core.tracers.log_stream import LogEntry
|
from langchain_core.tracers.log_stream import LogEntry
|
||||||
from langchain_core.tracers.memory_stream import _MemoryStream
|
from langchain_core.tracers.memory_stream import _MemoryStream
|
||||||
from langchain_core.utils.aiter import py_anext
|
from langchain_core.utils.aiter import aclosing, py_anext
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
@ -903,11 +903,10 @@ async def _astream_events_implementation_v2(
|
|||||||
async def consume_astream() -> None:
|
async def consume_astream() -> None:
|
||||||
try:
|
try:
|
||||||
# if astream also calls tap_output_aiter this will be a no-op
|
# if astream also calls tap_output_aiter this will be a no-op
|
||||||
async for _ in event_streamer.tap_output_aiter(
|
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
|
||||||
run_id, runnable.astream(input, config, **kwargs)
|
async for _ in event_streamer.tap_output_aiter(run_id, stream):
|
||||||
):
|
# All the content will be picked up
|
||||||
# All the content will be picked up
|
pass
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
await event_streamer.send_stream.aclose()
|
await event_streamer.send_stream.aclose()
|
||||||
|
|
||||||
@ -942,7 +941,8 @@ async def _astream_events_implementation_v2(
|
|||||||
yield event
|
yield event
|
||||||
finally:
|
finally:
|
||||||
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
||||||
try:
|
if task.cancel():
|
||||||
await task
|
try:
|
||||||
except asyncio.CancelledError:
|
await task
|
||||||
pass
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
@ -5,6 +5,8 @@ MIT License
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from contextlib import AbstractAsyncContextManager
|
||||||
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
@ -18,6 +20,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
@ -207,3 +210,40 @@ class Tee(Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
atee = Tee
|
atee = Tee
|
||||||
|
|
||||||
|
|
||||||
|
class aclosing(AbstractAsyncContextManager):
|
||||||
|
"""Async context manager for safely finalizing an asynchronously cleaned-up
|
||||||
|
resource such as an async generator, calling its ``aclose()`` method.
|
||||||
|
|
||||||
|
Code like this:
|
||||||
|
|
||||||
|
async with aclosing(<module>.fetch(<arguments>)) as agen:
|
||||||
|
<block>
|
||||||
|
|
||||||
|
is equivalent to this:
|
||||||
|
|
||||||
|
agen = <module>.fetch(<arguments>)
|
||||||
|
try:
|
||||||
|
<block>
|
||||||
|
finally:
|
||||||
|
await agen.aclose()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]
|
||||||
|
) -> None:
|
||||||
|
self.thing = thing
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]:
|
||||||
|
return self.thing
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc_value: Optional[BaseException],
|
||||||
|
traceback: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
if hasattr(self.thing, "aclose"):
|
||||||
|
await self.thing.aclose()
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Module that contains tests for runnable.astream_events API."""
|
"""Module that contains tests for runnable.astream_events API."""
|
||||||
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
@ -38,6 +39,7 @@ from langchain_core.runnables import (
|
|||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
RunnableGenerator,
|
RunnableGenerator,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
|
chain,
|
||||||
ensure_config,
|
ensure_config,
|
||||||
)
|
)
|
||||||
from langchain_core.runnables.config import get_callback_manager_for_config
|
from langchain_core.runnables.config import get_callback_manager_for_config
|
||||||
@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
|
|||||||
from langchain_core.runnables.schema import StreamEvent
|
from langchain_core.runnables.schema import StreamEvent
|
||||||
from langchain_core.runnables.utils import Input, Output
|
from langchain_core.runnables.utils import Input, Output
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.utils.aiter import aclosing
|
||||||
from tests.unit_tests.stubs import AnyStr
|
from tests.unit_tests.stubs import AnyStr
|
||||||
|
|
||||||
|
|
||||||
@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
|
|||||||
for event in events
|
for event in events
|
||||||
if event["event"] == "on_chat_model_stream"
|
if event["event"] == "on_chat_model_stream"
|
||||||
] == ["hello", " ", "world"]
|
] == ["hello", " ", "world"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_break_astream_events() -> None:
|
||||||
|
class AwhileMaker:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
async def __call__(self, input: Any) -> Any:
|
||||||
|
self.started = True
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
return input
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.cancelled = True
|
||||||
|
raise
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.started = False
|
||||||
|
self.cancelled = False
|
||||||
|
|
||||||
|
alittlewhile = AwhileMaker()
|
||||||
|
awhile = AwhileMaker()
|
||||||
|
anotherwhile = AwhileMaker()
|
||||||
|
|
||||||
|
outer_cancelled = False
|
||||||
|
|
||||||
|
@chain
|
||||||
|
async def sequence(input: Any) -> Any:
|
||||||
|
try:
|
||||||
|
yield await alittlewhile(input)
|
||||||
|
yield await awhile(input)
|
||||||
|
yield await anotherwhile(input)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
nonlocal outer_cancelled
|
||||||
|
outer_cancelled = True
|
||||||
|
raise
|
||||||
|
|
||||||
|
# test interrupting astream_events v2
|
||||||
|
|
||||||
|
got_event = False
|
||||||
|
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
||||||
|
async with aclosing(
|
||||||
|
sequence.astream_events({"value": 1}, thread2, version="v2")
|
||||||
|
) as stream:
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk["event"] == "on_chain_stream":
|
||||||
|
got_event = True
|
||||||
|
assert chunk["data"]["chunk"] == {"value": 1}
|
||||||
|
break
|
||||||
|
|
||||||
|
# did break
|
||||||
|
assert got_event
|
||||||
|
# did cancel outer chain
|
||||||
|
assert outer_cancelled
|
||||||
|
|
||||||
|
# node "alittlewhile" starts, not cancelled
|
||||||
|
assert alittlewhile.started is True
|
||||||
|
assert alittlewhile.cancelled is False
|
||||||
|
|
||||||
|
# node "awhile" starts but is cancelled
|
||||||
|
assert awhile.started is True
|
||||||
|
assert awhile.cancelled is True
|
||||||
|
|
||||||
|
# node "anotherwhile" should never start
|
||||||
|
assert anotherwhile.started is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cancel_astream_events() -> None:
|
||||||
|
class AwhileMaker:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
async def __call__(self, input: Any) -> Any:
|
||||||
|
self.started = True
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
return input
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
self.cancelled = True
|
||||||
|
raise
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.started = False
|
||||||
|
self.cancelled = False
|
||||||
|
|
||||||
|
alittlewhile = AwhileMaker()
|
||||||
|
awhile = AwhileMaker()
|
||||||
|
anotherwhile = AwhileMaker()
|
||||||
|
|
||||||
|
outer_cancelled = False
|
||||||
|
|
||||||
|
@chain
|
||||||
|
async def sequence(input: Any) -> Any:
|
||||||
|
try:
|
||||||
|
yield await alittlewhile(input)
|
||||||
|
yield await awhile(input)
|
||||||
|
yield await anotherwhile(input)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
nonlocal outer_cancelled
|
||||||
|
outer_cancelled = True
|
||||||
|
raise
|
||||||
|
|
||||||
|
got_event = False
|
||||||
|
|
||||||
|
async def aconsume(stream: AsyncIterator[Any]) -> None:
|
||||||
|
nonlocal got_event
|
||||||
|
# here we don't need aclosing as cancelling the task is propagated
|
||||||
|
# to the async generator being consumed
|
||||||
|
async for chunk in stream:
|
||||||
|
if chunk["event"] == "on_chain_stream":
|
||||||
|
got_event = True
|
||||||
|
assert chunk["data"]["chunk"] == {"value": 1}
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
||||||
|
task = asyncio.create_task(
|
||||||
|
aconsume(sequence.astream_events({"value": 1}, thread2, version="v2"))
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
# did break
|
||||||
|
assert got_event
|
||||||
|
# did cancel outer chain
|
||||||
|
assert outer_cancelled
|
||||||
|
|
||||||
|
# node "alittlewhile" starts, not cancelled
|
||||||
|
assert alittlewhile.started is True
|
||||||
|
assert alittlewhile.cancelled is False
|
||||||
|
|
||||||
|
# node "awhile" starts but is cancelled
|
||||||
|
assert awhile.started is True
|
||||||
|
assert awhile.cancelled is True
|
||||||
|
|
||||||
|
# node "anotherwhile" should never start
|
||||||
|
assert anotherwhile.started is False
|
||||||
|
Loading…
Reference in New Issue
Block a user