mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-30 17:29:56 +00:00
core[patch]: In astream_events(version=v2) tap output of root run (#21977)
- If tap_output_iter/aiter is called multiple times for the same run, issue events only once.
- If a chat model run is tapped, don't issue duplicate on_llm_new_token events.
- If the first chunk arrives after the run has ended, do not emit it as a stream event.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
72d4a8eeed
commit
fb6108c8f5
@ -124,7 +124,6 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
metadata={},
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
|
@ -17,7 +17,7 @@ from typing import (
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
@ -37,6 +37,7 @@ from langchain_core.runnables.utils import (
|
||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||
from langchain_core.tracers.log_stream import LogEntry
|
||||
from langchain_core.tracers.memory_stream import _MemoryStream
|
||||
from langchain_core.utils.aiter import py_anext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.documents import Document
|
||||
@ -87,6 +88,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
super().__init__(*args, **kwargs)
|
||||
# Map of run ID to run info.
|
||||
self.run_map: Dict[UUID, RunInfo] = {}
|
||||
self.is_tapped: Dict[UUID, Any] = {}
|
||||
|
||||
# Filter which events will be sent over the queue.
|
||||
self.root_event_filter = _RootEventFilter(
|
||||
@ -116,41 +118,85 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
self, run_id: UUID, output: AsyncIterator[T]
|
||||
) -> AsyncIterator[T]:
|
||||
"""Tap the output aiter."""
|
||||
async for chunk in output:
|
||||
run_info = self.run_map.get(run_id)
|
||||
if run_info is None:
|
||||
raise AssertionError(f"Run ID {run_id} not found in run map.")
|
||||
self._send(
|
||||
{
|
||||
"event": f"on_{run_info['run_type']}_stream",
|
||||
"data": {"chunk": chunk},
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
"tags": run_info["tags"],
|
||||
"metadata": run_info["metadata"],
|
||||
},
|
||||
run_info["run_type"],
|
||||
)
|
||||
yield chunk
|
||||
sentinel = object()
|
||||
# atomic check and set
|
||||
tap = self.is_tapped.setdefault(run_id, sentinel)
|
||||
# wait for first chunk
|
||||
first = await py_anext(output, default=sentinel)
|
||||
if first is sentinel:
|
||||
return
|
||||
# get run info
|
||||
run_info = self.run_map.get(run_id)
|
||||
if run_info is None:
|
||||
# run has finished, don't issue any stream events
|
||||
yield cast(T, first)
|
||||
return
|
||||
if tap is sentinel:
|
||||
# if we are the first to tap, issue stream events
|
||||
event: StreamEvent = {
|
||||
"event": f"on_{run_info['run_type']}_stream",
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
"tags": run_info["tags"],
|
||||
"metadata": run_info["metadata"],
|
||||
"data": {},
|
||||
}
|
||||
self._send({**event, "data": {"chunk": first}}, run_info["run_type"])
|
||||
yield cast(T, first)
|
||||
# consume the rest of the output
|
||||
async for chunk in output:
|
||||
self._send(
|
||||
{**event, "data": {"chunk": chunk}},
|
||||
run_info["run_type"],
|
||||
)
|
||||
yield chunk
|
||||
else:
|
||||
# otherwise just pass through
|
||||
yield cast(T, first)
|
||||
# consume the rest of the output
|
||||
async for chunk in output:
|
||||
yield chunk
|
||||
|
||||
def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
|
||||
"""Tap the output aiter."""
|
||||
for chunk in output:
|
||||
run_info = self.run_map.get(run_id)
|
||||
if run_info is None:
|
||||
raise AssertionError(f"Run ID {run_id} not found in run map.")
|
||||
self._send(
|
||||
{
|
||||
"event": f"on_{run_info['run_type']}_stream",
|
||||
"data": {"chunk": chunk},
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
"tags": run_info["tags"],
|
||||
"metadata": run_info["metadata"],
|
||||
},
|
||||
run_info["run_type"],
|
||||
)
|
||||
yield chunk
|
||||
sentinel = object()
|
||||
# atomic check and set
|
||||
tap = self.is_tapped.setdefault(run_id, sentinel)
|
||||
# wait for first chunk
|
||||
first = next(output, sentinel)
|
||||
if first is sentinel:
|
||||
return
|
||||
# get run info
|
||||
run_info = self.run_map.get(run_id)
|
||||
if run_info is None:
|
||||
# run has finished, don't issue any stream events
|
||||
yield cast(T, first)
|
||||
return
|
||||
if tap is sentinel:
|
||||
# if we are the first to tap, issue stream events
|
||||
event: StreamEvent = {
|
||||
"event": f"on_{run_info['run_type']}_stream",
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
"tags": run_info["tags"],
|
||||
"metadata": run_info["metadata"],
|
||||
"data": {},
|
||||
}
|
||||
self._send({**event, "data": {"chunk": first}}, run_info["run_type"])
|
||||
yield cast(T, first)
|
||||
# consume the rest of the output
|
||||
for chunk in output:
|
||||
self._send(
|
||||
{**event, "data": {"chunk": chunk}},
|
||||
run_info["run_type"],
|
||||
)
|
||||
yield chunk
|
||||
else:
|
||||
# otherwise just pass through
|
||||
yield cast(T, first)
|
||||
# consume the rest of the output
|
||||
for chunk in output:
|
||||
yield chunk
|
||||
|
||||
async def on_chat_model_start(
|
||||
self,
|
||||
@ -244,6 +290,8 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
|
||||
if run_info is None:
|
||||
raise AssertionError(f"Run ID {run_id} not found in run map.")
|
||||
if self.is_tapped.get(run_id):
|
||||
return
|
||||
if run_info["run_type"] == "chat_model":
|
||||
event = "on_chat_model_stream"
|
||||
|
||||
@ -748,6 +796,7 @@ async def _astream_events_implementation_v2(
|
||||
|
||||
# Assign the stream handler to the config
|
||||
config = ensure_config(config)
|
||||
run_id = cast(UUID, config.setdefault("run_id", uuid4()))
|
||||
callbacks = config.get("callbacks")
|
||||
if callbacks is None:
|
||||
config["callbacks"] = [event_streamer]
|
||||
@ -767,7 +816,10 @@ async def _astream_events_implementation_v2(
|
||||
# add each chunk to the output stream
|
||||
async def consume_astream() -> None:
|
||||
try:
|
||||
async for _ in runnable.astream(input, config, **kwargs):
|
||||
# if astream also calls tap_output_aiter this will be a no-op
|
||||
async for _ in event_streamer.tap_output_aiter(
|
||||
run_id, runnable.astream(input, config, **kwargs)
|
||||
):
|
||||
# All the content will be picked up
|
||||
pass
|
||||
finally:
|
||||
|
@ -1440,7 +1440,6 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
|
||||
recursion_limit=25,
|
||||
configurable={"hello": "there"},
|
||||
metadata={"hello": "there", "bye": "now"},
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
spy.reset_mock()
|
||||
@ -1582,7 +1581,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
@ -1592,7 +1590,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
),
|
||||
]
|
||||
@ -1618,7 +1615,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
|
||||
@ -1629,7 +1625,6 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
recursion_limit=5,
|
||||
run_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1700,7 +1695,6 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
recursion_limit=25,
|
||||
run_id=None,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,7 +1,17 @@
|
||||
"""Module that contains tests for runnable.astream_events API."""
|
||||
import sys
|
||||
from itertools import cycle
|
||||
from typing import Any, AsyncIterator, Dict, List, Sequence, cast
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
cast,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
@ -9,6 +19,7 @@ from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
@ -25,9 +36,12 @@ from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
ensure_config,
|
||||
)
|
||||
from langchain_core.runnables.config import get_callback_manager_for_config
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.schema import StreamEvent
|
||||
from langchain_core.runnables.utils import Input, Output
|
||||
from langchain_core.tools import tool
|
||||
from tests.unit_tests.stubs import AnyStr
|
||||
|
||||
@ -1707,3 +1721,112 @@ async def test_sync_in_sync_lambdas() -> None:
|
||||
|
||||
events = await _collect_events(add_one_proxy_.astream_events(1, version="v2"))
|
||||
assert events == EXPECTED_EVENTS
|
||||
|
||||
|
||||
class StreamingRunnable(Runnable[Input, Output]):
|
||||
"""A custom runnable used for testing purposes"""
|
||||
|
||||
iterable: Iterable[Any]
|
||||
|
||||
def __init__(self, iterable: Iterable[Any]) -> None:
|
||||
"""Initialize the runnable."""
|
||||
self.iterable = iterable
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
"""Invoke the runnable."""
|
||||
raise ValueError("Server side error")
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name", self.get_name()),
|
||||
run_id=config.get("run_id"),
|
||||
)
|
||||
|
||||
try:
|
||||
final_output = None
|
||||
for element in self.iterable:
|
||||
if isinstance(element, BaseException):
|
||||
raise element
|
||||
yield element
|
||||
|
||||
if final_output is None:
|
||||
final_output = element
|
||||
else:
|
||||
try:
|
||||
final_output = final_output + element
|
||||
except TypeError:
|
||||
final_output = element
|
||||
|
||||
# set final channel values as run output
|
||||
run_manager.on_chain_end(final_output)
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise
|
||||
|
||||
|
||||
async def test_astream_events_from_custom_runnable() -> None:
|
||||
"""Test astream events from a custom runnable."""
|
||||
iterator = ["1", "2", "3"]
|
||||
runnable: Runnable[int, str] = StreamingRunnable(iterator)
|
||||
chunks = [chunk async for chunk in runnable.astream(1, version="v2")]
|
||||
assert chunks == ["1", "2", "3"]
|
||||
events = await _collect_events(runnable.astream_events(1, version="v2"))
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": 1},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "StreamingRunnable",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": "1"},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "StreamingRunnable",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": "2"},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "StreamingRunnable",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": "3"},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "StreamingRunnable",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": "123"},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "StreamingRunnable",
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user