core[patch]: In astream_events(version=v2) tap output of root run (#21977)

- If tap_output_iter/aiter is called multiple times for the same run,
  issue stream events only once.
- If a chat model run is tapped, don't issue duplicate on_llm_new_token
  events.
- If the first chunk arrives after the run has ended, do not emit it as a
  stream event.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Nuno Campos 2024-05-21 11:03:57 -07:00 committed by GitHub
parent 72d4a8eeed
commit fb6108c8f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 210 additions and 42 deletions

View File

@ -124,7 +124,6 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
metadata={},
callbacks=None,
recursion_limit=25,
run_id=None,
)
if var_config := var_child_runnable_config.get():
empty.update(

View File

@ -17,7 +17,7 @@ from typing import (
Union,
cast,
)
from uuid import UUID
from uuid import UUID, uuid4
from typing_extensions import NotRequired, TypedDict
@ -37,6 +37,7 @@ from langchain_core.runnables.utils import (
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import LogEntry
from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import py_anext
if TYPE_CHECKING:
from langchain_core.documents import Document
@ -87,6 +88,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
super().__init__(*args, **kwargs)
# Map of run ID to run info.
self.run_map: Dict[UUID, RunInfo] = {}
self.is_tapped: Dict[UUID, Any] = {}
# Filter which events will be sent over the queue.
self.root_event_filter = _RootEventFilter(
@ -116,41 +118,85 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
self, run_id: UUID, output: AsyncIterator[T]
) -> AsyncIterator[T]:
"""Tap the output aiter."""
async for chunk in output:
run_info = self.run_map.get(run_id)
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
self._send(
{
"event": f"on_{run_info['run_type']}_stream",
"data": {"chunk": chunk},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
},
run_info["run_type"],
)
yield chunk
sentinel = object()
# atomic check and set
tap = self.is_tapped.setdefault(run_id, sentinel)
# wait for first chunk
first = await py_anext(output, default=sentinel)
if first is sentinel:
return
# get run info
run_info = self.run_map.get(run_id)
if run_info is None:
# run has finished, don't issue any stream events
yield cast(T, first)
return
if tap is sentinel:
# if we are the first to tap, issue stream events
event: StreamEvent = {
"event": f"on_{run_info['run_type']}_stream",
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
"data": {},
}
self._send({**event, "data": {"chunk": first}}, run_info["run_type"])
yield cast(T, first)
# consume the rest of the output
async for chunk in output:
self._send(
{**event, "data": {"chunk": chunk}},
run_info["run_type"],
)
yield chunk
else:
# otherwise just pass through
yield cast(T, first)
# consume the rest of the output
async for chunk in output:
yield chunk
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
    """Tap a synchronous output iterator, sending a stream event per chunk.

    Every chunk of ``output`` is yielded unchanged and in order. Only the
    first tap registered for ``run_id`` sends ``on_<run_type>_stream``
    events through ``self._send``; any later tap of the same run forwards
    the chunks without sending events, so events are not duplicated.

    Args:
        run_id: ID of the run whose output is being tapped.
        output: The iterator producing the run's output chunks.

    Yields:
        The chunks of ``output``.
    """
    sentinel = object()
    # Check-and-set in a single dict call: ``setdefault`` hands back our own
    # sentinel only when no earlier tap has registered for this run.
    tap = self.is_tapped.setdefault(run_id, sentinel)
    # Wait for the first chunk before looking the run up.
    first = next(output, sentinel)
    if first is sentinel:
        # Empty output: nothing to forward and nothing to report.
        return
    run_info = self.run_map.get(run_id)
    if run_info is None or tap is not sentinel:
        # Either the run is absent from the run map (it has already
        # finished) or another tap is already reporting this run. In both
        # cases forward *all* chunks untouched and send no stream events.
        yield cast("T", first)
        yield from output
        return
    # We are the first tap of a live run: send one stream event per chunk.
    run_type = run_info["run_type"]
    event: StreamEvent = {
        "event": f"on_{run_type}_stream",
        "run_id": str(run_id),
        "name": run_info["name"],
        "tags": run_info["tags"],
        "metadata": run_info["metadata"],
        "data": {},
    }
    self._send({**event, "data": {"chunk": first}}, run_type)
    yield cast("T", first)
    for chunk in output:
        self._send({**event, "data": {"chunk": chunk}}, run_type)
        yield chunk
async def on_chat_model_start(
self,
@ -244,6 +290,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
if run_info is None:
raise AssertionError(f"Run ID {run_id} not found in run map.")
if self.is_tapped.get(run_id):
return
if run_info["run_type"] == "chat_model":
event = "on_chat_model_stream"
@ -748,6 +796,7 @@ async def _astream_events_implementation_v2(
# Assign the stream handler to the config
config = ensure_config(config)
run_id = cast(UUID, config.setdefault("run_id", uuid4()))
callbacks = config.get("callbacks")
if callbacks is None:
config["callbacks"] = [event_streamer]
@ -767,7 +816,10 @@ async def _astream_events_implementation_v2(
# add each chunk to the output stream
async def consume_astream() -> None:
try:
async for _ in runnable.astream(input, config, **kwargs):
# if astream also calls tap_output_aiter this will be a no-op
async for _ in event_streamer.tap_output_aiter(
run_id, runnable.astream(input, config, **kwargs)
):
# All the content will be picked up
pass
finally:

View File

@ -1440,7 +1440,6 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
recursion_limit=25,
configurable={"hello": "there"},
metadata={"hello": "there", "bye": "now"},
run_id=None,
),
)
spy.reset_mock()
@ -1582,7 +1581,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
mocker.call(
@ -1592,7 +1590,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
),
]
@ -1618,7 +1615,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
@ -1629,7 +1625,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
tags=["c"],
callbacks=None,
recursion_limit=5,
run_id=None,
),
)
@ -1700,7 +1695,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
tags=[],
callbacks=None,
recursion_limit=25,
run_id=None,
)

View File

@ -1,7 +1,17 @@
"""Module that contains tests for runnable.astream_events API."""
import sys
from itertools import cycle
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
cast,
)
import pytest
@ -9,6 +19,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
from langchain_core.load import dumpd
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
@ -25,9 +36,12 @@ from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,
)
from langchain_core.runnables.config import get_callback_manager_for_config
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from tests.unit_tests.stubs import AnyStr
@ -1707,3 +1721,112 @@ async def test_sync_in_sync_lambdas() -> None:
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
assert events == EXPECTED_EVENTS
class StreamingRunnable(Runnable[Input, Output]):
    """Test-only runnable that replays the items of a fixed iterable.

    Only ``astream`` does real work. ``invoke`` always raises ``ValueError``
    and ``stream`` always raises ``NotImplementedError``.
    """

    iterable: Iterable[Any]

    def __init__(self, iterable: Iterable[Any]) -> None:
        """Store the iterable whose items ``astream`` will yield."""
        self.iterable = iterable

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
        """Always raise ``ValueError``; ``invoke`` is not supported here."""
        raise ValueError("Server side error")

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Always raise ``NotImplementedError``; only ``astream`` exists."""
        raise NotImplementedError()

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Yield each stored item while reporting a chain run to callbacks.

        A chain run is started through the callback manager built from
        ``config``. An item that is itself an exception instance is raised
        instead of yielded. Yielded items are folded together with ``+``
        (restarting from the current item when ``+`` raises ``TypeError``)
        and the folded value is reported as the run's output. Any exception
        is reported via ``on_chain_error`` and then re-raised.
        """
        config = ensure_config(config)
        manager = get_callback_manager_for_config(config)
        serialized = dumpd(self)
        run_name = config.get("run_name", self.get_name())
        run_id = config.get("run_id")
        run_manager = manager.on_chain_start(
            serialized,
            input,
            name=run_name,
            run_id=run_id,
        )
        try:
            aggregate = None
            for item in self.iterable:
                if isinstance(item, BaseException):
                    raise item
                yield item
                if aggregate is not None:
                    try:
                        aggregate = aggregate + item
                    except TypeError:
                        # Items cannot be combined; keep the latest one.
                        aggregate = item
                else:
                    aggregate = item

            # Report the folded value as the output of the run.
            run_manager.on_chain_end(aggregate)
        except BaseException as err:
            run_manager.on_chain_error(err)
            raise
async def test_astream_events_from_custom_runnable() -> None:
    """Check chunks and v2 events produced by a custom streaming runnable."""
    source = ["1", "2", "3"]
    runnable: Runnable[int, str] = StreamingRunnable(source)

    # Plain streaming must replay the items unchanged.
    streamed = []
    async for piece in runnable.astream(1, version="v2"):
        streamed.append(piece)
    assert streamed == ["1", "2", "3"]

    # Fields shared by every expected event of the single root run.
    shared = {
        "metadata": {},
        "name": "StreamingRunnable",
        "run_id": "",
        "tags": [],
    }
    expected = [
        {"data": {"input": 1}, "event": "on_chain_start", **shared},
        *(
            {"data": {"chunk": piece}, "event": "on_chain_stream", **shared}
            for piece in ("1", "2", "3")
        ),
        {"data": {"output": "123"}, "event": "on_chain_end", **shared},
    ]

    events = await _collect_events(runnable.astream_events(1, version="v2"))
    assert events == expected