diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index b89a3da38fc..4bf32c605f0 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -124,7 +124,6 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: metadata={}, callbacks=None, recursion_limit=25, - run_id=None, ) if var_config := var_child_runnable_config.get(): empty.update( diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 40038a88209..ee591d46a1e 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -17,7 +17,7 @@ from typing import ( Union, cast, ) -from uuid import UUID +from uuid import UUID, uuid4 from typing_extensions import NotRequired, TypedDict @@ -37,6 +37,7 @@ from langchain_core.runnables.utils import ( from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers.log_stream import LogEntry from langchain_core.tracers.memory_stream import _MemoryStream +from langchain_core.utils.aiter import py_anext if TYPE_CHECKING: from langchain_core.documents import Document @@ -87,6 +88,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand super().__init__(*args, **kwargs) # Map of run ID to run info. self.run_map: Dict[UUID, RunInfo] = {} + self.is_tapped: Dict[UUID, Any] = {} # Filter which events will be sent over the queue. self.root_event_filter = _RootEventFilter( @@ -116,41 +118,85 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand self, run_id: UUID, output: AsyncIterator[T] ) -> AsyncIterator[T]: """Tap the output aiter.""" - async for chunk in output: - run_info = self.run_map.get(run_id) - if run_info is None: - raise AssertionError(f"Run ID {run_id} not found in run map.") - self._send( - { - "event": f"on_{run_info['run_type']}_stream", - "data": {"chunk": chunk}, - "run_id": str(run_id), - "name": run_info["name"], - "tags": run_info["tags"], - "metadata": run_info["metadata"], - }, - run_info["run_type"], - ) - yield chunk + sentinel = object() + # atomic check and set + tap = self.is_tapped.setdefault(run_id, sentinel) + # wait for first chunk + first = await py_anext(output, default=sentinel) + if first is sentinel: + return + # get run info + run_info = self.run_map.get(run_id) + if run_info is None: + # run has finished, don't issue any stream events + yield cast(T, first) + return + if tap is sentinel: + # if we are the first to tap, issue stream events + event: StreamEvent = { + "event": f"on_{run_info['run_type']}_stream", + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + "data": {}, + } + self._send({**event, "data": {"chunk": first}}, run_info["run_type"]) + yield cast(T, first) + # consume the rest of the output + async for chunk in output: + self._send( + {**event, "data": {"chunk": chunk}}, + run_info["run_type"], + ) + yield chunk + else: + # otherwise just pass through + yield cast(T, first) + # consume the rest of the output + async for chunk in output: + yield chunk def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]: """Tap the output aiter.""" - for chunk in output: - run_info = self.run_map.get(run_id) - if run_info is None: - raise AssertionError(f"Run ID {run_id} not found in run map.") - self._send( - { - "event": f"on_{run_info['run_type']}_stream", - "data": {"chunk": chunk}, - "run_id": str(run_id), - "name": run_info["name"], - "tags": run_info["tags"], - "metadata": run_info["metadata"], - }, - run_info["run_type"], - ) - yield chunk + sentinel = object() + # atomic check and set + tap = self.is_tapped.setdefault(run_id, sentinel) + # wait for first chunk + first = next(output, sentinel) + if first is sentinel: + return + # get run info + run_info = self.run_map.get(run_id) + if run_info is None: + # run has finished, don't issue any stream events + yield cast(T, first) + return + if tap is sentinel: + # if we are the first to tap, issue stream events + event: StreamEvent = { + "event": f"on_{run_info['run_type']}_stream", + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + "data": {}, + } + self._send({**event, "data": {"chunk": first}}, run_info["run_type"]) + yield cast(T, first) + # consume the rest of the output + for chunk in output: + self._send( + {**event, "data": {"chunk": chunk}}, + run_info["run_type"], + ) + yield chunk + else: + # otherwise just pass through + yield cast(T, first) + # consume the rest of the output + for chunk in output: + yield chunk async def on_chat_model_start( self, @@ -244,6 +290,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand if run_info is None: raise AssertionError(f"Run ID {run_id} not found in run map.") + if self.is_tapped.get(run_id): + return if run_info["run_type"] == "chat_model": event = "on_chat_model_stream" @@ -748,6 +796,7 @@ async def _astream_events_implementation_v2( # Assign the stream handler to the config config = ensure_config(config) + run_id = cast(UUID, config.setdefault("run_id", uuid4())) callbacks = config.get("callbacks") if callbacks is None: config["callbacks"] = [event_streamer] @@ -767,7 +816,10 @@ async def _astream_events_implementation_v2( # add each chunk to the output stream async def consume_astream() -> None: try: - async for _ in runnable.astream(input, config, **kwargs): + # if astream also calls tap_output_aiter this will be a no-op + async for _ in event_streamer.tap_output_aiter( + run_id, runnable.astream(input, config, **kwargs) + ): # All the content will be picked up pass finally: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index dd643a5e54c..3fbd7c57c18 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1440,7 +1440,6 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: recursion_limit=25, configurable={"hello": "there"}, metadata={"hello": "there", "bye": "now"}, - run_id=None, ), ) spy.reset_mock() @@ -1582,7 +1581,6 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, - run_id=None, ), ), mocker.call( @@ -1592,7 +1590,6 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, - run_id=None, ), ), ] @@ -1618,7 +1615,6 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, - run_id=None, ), ) second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld") @@ -1629,7 +1625,6 @@ async def test_with_config(mocker: MockerFixture) -> None: tags=["c"], callbacks=None, recursion_limit=5, - run_id=None, ), ) @@ -1700,7 +1695,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: tags=[], callbacks=None, recursion_limit=25, - run_id=None, ) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index e5a90c1f2c9..4ec4d4a3500 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -1,7 +1,17 @@ """Module that contains tests for runnable.astream_events API.""" import sys from itertools import cycle -from typing import Any, AsyncIterator, Dict, List, Sequence, cast +from typing import ( + Any, + AsyncIterator, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + cast, +) import pytest @@ -9,6 +19,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.documents import Document from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel +from langchain_core.load import dumpd from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -25,9 +36,12 @@ from langchain_core.runnables import ( Runnable, RunnableConfig, RunnableLambda, + ensure_config, ) +from langchain_core.runnables.config import get_callback_manager_for_config from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent +from langchain_core.runnables.utils import Input, Output from langchain_core.tools import tool from tests.unit_tests.stubs import AnyStr @@ -1707,3 +1721,112 @@ async def test_sync_in_sync_lambdas() -> None: events = await _collect_events(add_one_proxy_.astream_events(1, version="v2")) assert events == EXPECTED_EVENTS + + +class StreamingRunnable(Runnable[Input, Output]): + """A custom runnable used for testing purposes""" + + iterable: Iterable[Any] + + def __init__(self, iterable: Iterable[Any]) -> None: + """Initialize the runnable.""" + self.iterable = iterable + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + """Invoke the runnable.""" + raise ValueError("Server side error") + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + raise NotImplementedError() + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name", self.get_name()), + run_id=config.get("run_id"), + ) + + try: + final_output = None + for element in self.iterable: + if isinstance(element, BaseException): + raise element + yield element + + if final_output is None: + final_output = element + else: + try: + final_output = final_output + element + except TypeError: + final_output = element + + # set final channel values as run output + run_manager.on_chain_end(final_output) + except BaseException as e: + run_manager.on_chain_error(e) + raise + + +async def test_astream_events_from_custom_runnable() -> None: + """Test astream events from a custom runnable.""" + iterator = ["1", "2", "3"] + runnable: Runnable[int, str] = StreamingRunnable(iterator) + chunks = [chunk async for chunk in runnable.astream(1, version="v2")] + assert chunks == ["1", "2", "3"] + events = await _collect_events(runnable.astream_events(1, version="v2")) + assert events == [ + { + "data": {"input": 1}, + "event": "on_chain_start", + "metadata": {}, + "name": "StreamingRunnable", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "1"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "StreamingRunnable", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "2"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "StreamingRunnable", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "3"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "StreamingRunnable", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "123"}, + "event": "on_chain_end", + "metadata": {}, + "name": "StreamingRunnable", + "run_id": "", + "tags": [], + }, + ]