From 5c2cfabec6fbc6d7ccdf0531291b67057b50986d Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 15 May 2024 11:48:47 -0400 Subject: [PATCH] core[minor]: Add v2 implementation of astream events (#21638) This PR introduces a v2 implementation of astream events that removes intermediate abstractions and fixes some issues with v1 implementation. The v2 implementation significantly reduces relevant code that's associated with the astream events implementation together with overhead. After this PR, the astream events implementation: - Uses an async callback handler - No longer relies on BaseTracer - No longer relies on json patch As a result of this re-write, a number of issues were discovered with the existing implementation. ## Changes in V2 vs. V1 ### on_chat_model_end `output` The outputs associated with `on_chat_model_end` changed depending on whether it was within a chain or not. As a root level runnable the output was: ```python "data": {"output": AIMessageChunk(content="hello world!", id='some id')} ``` As part of a chain the output was: ``` "data": { "output": { "generations": [ [ { "generation_info": None, "message": AIMessageChunk( content="hello world!", id=AnyStr() ), "text": "hello world!", "type": "ChatGenerationChunk", } ] ], "llm_output": None, } }, ``` After this PR, we will always use the simpler representation: ```python "data": {"output": AIMessageChunk(content="hello world!", id='some id')} ``` **NOTE** Non chat models (i.e., regular LLMs) are still associated with the more verbose format. ### Remove some `_stream` events `on_retriever_stream` and `on_tool_stream` events were removed -- these were not real events, but created as an artifact of implementing on top of astream_log. The same information is already available in the `x_on_end` events. ### Propagating Names Names of runnables have been updated to be more consistent ```python model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields( messages=ConfigurableField( id="messages", name="Messages", description="Messages return by the LLM", ) ) ``` Before: ```python "name": "RunnableConfigurableFields", ``` After: ```python "name": "GenericFakeChatModel", ``` ### on_retriever_end on_retriever_end will always return `output` which is a list of documents (rather than a dict containing a key called "documents") ### Retry events Removed the `on_retry` callback handler. It was incorrectly showing that the failed function being retried has invoked `on_chain_end` https://github.com/langchain-ai/langchain/pull/21638/files#diff-e512e3f84daf23029ebcceb11460f1c82056314653673e450a5831147d8cb84dL1394 --- .../language_models/chat_models.py | 6 +- libs/core/langchain_core/runnables/base.py | 243 +-- .../core/langchain_core/tracers/_streaming.py | 28 + .../langchain_core/tracers/event_stream.py | 790 ++++++++ .../core/langchain_core/tracers/log_stream.py | 3 +- libs/core/poetry.lock | 5 +- ...e_events.py => test_runnable_events_v1.py} | 0 .../runnables/test_runnable_events_v2.py | 1694 +++++++++++++++++ 8 files changed, 2574 insertions(+), 195 deletions(-) create mode 100644 libs/core/langchain_core/tracers/_streaming.py create mode 100644 libs/core/langchain_core/tracers/event_stream.py rename libs/core/tests/unit_tests/runnables/{test_runnable_events.py => test_runnable_events_v1.py} (100%) create mode 100644 libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 6abd5880553..f6967f51c09 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -52,7 +52,7 @@ from langchain_core.outputs import ( from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables.config import ensure_config, run_in_executor -from langchain_core.tracers.log_stream import LogStreamCallbackHandler +from langchain_core.tracers._streaming import _StreamingCallbackHandler if TYPE_CHECKING: from langchain_core.pydantic_v1 import BaseModel @@ -608,7 +608,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ( True for h in run_manager.handlers - if isinstance(h, LogStreamCallbackHandler) + if isinstance(h, _StreamingCallbackHandler) ), False, ) @@ -691,7 +691,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ( True for h in run_manager.handlers - if isinstance(h, LogStreamCallbackHandler) + if isinstance(h, _StreamingCallbackHandler) ), False, ) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index def410ba693..c2dfa420183 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -58,7 +58,7 @@ from langchain_core.runnables.config import ( var_child_runnable_config, ) from langchain_core.runnables.graph import Graph -from langchain_core.runnables.schema import EventData, StreamEvent +from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import ( AddableDict, AnyConfigurableField, @@ -90,7 +90,6 @@ if TYPE_CHECKING: RunnableWithFallbacks as RunnableWithFallbacksT, ) from langchain_core.tracers.log_stream import ( - LogEntry, RunLog, RunLogPatch, ) @@ -927,7 +926,7 @@ class Runnable(Generic[Input, Output], ABC): input: Any, config: Optional[RunnableConfig] = None, *, - version: Literal["v1"], + version: Literal["v1", "v2"], include_names: Optional[Sequence[str]] = None, include_types: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None, @@ -962,6 +961,8 @@ class Runnable(Generic[Input, Output], ABC): chains. Metadata fields have been omitted from the table for brevity. Chain definitions have been included after the table. + **ATTENTION** This reference table is for the V2 version of the schema. + +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | event | name | chunk | input | output | +======================+==================+=================================+===============================================+=================================================+ @@ -969,7 +970,7 @@ class Runnable(Generic[Input, Output], ABC): +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_chat_model_stream | [model name] | AIMessageChunk(content="hello") | | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | {"generations": [...], "llm_output": None, ...} | + | on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | AIMessageChunk(content="hello world") | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_llm_start | [model name] | | {'input': 'hello'} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ @@ -985,15 +986,11 @@ class Runnable(Generic[Input, Output], ABC): +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_tool_start | some_tool | | {"x": 1, "y": "2"} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_tool_stream | some_tool | {"x": 1, "y": "2"} | | | - +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_tool_end | some_tool | | | {"x": 1, "y": "2"} | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_retriever_start | [retriever name] | | {"query": "hello"} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_retriever_chunk | [retriever name] | {documents: [...]} | | | - +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_retriever_end | [retriever name] | | {"query": "hello"} | {documents: [...]} | + | on_retriever_end | [retriever name] | | {"query": "hello"} | [Document(...), ..] | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_prompt_start | [template_name] | | {"question": "hello"} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ @@ -1042,7 +1039,7 @@ class Runnable(Generic[Input, Output], ABC): chain = RunnableLambda(func=reverse) events = [ - event async for event in chain.astream_events("hello", version="v1") + event async for event in chain.astream_events("hello", version="v2") ] # will produce the following events (run_id has been omitted for brevity): @@ -1073,8 +1070,10 @@ class Runnable(Generic[Input, Output], ABC): Args: input: The input to the runnable. config: The config to use for the runnable. - version: The version of the schema to use. - Currently only version 1 is available. + version: The version of the schema to use either `v2` or `v1`. + Users should use `v2`. + `v1` is for backwards compatibility and will be deprecated + in 0.4.0. No default will be assigned until the API is stabilized. include_names: Only include events from runnables with matching names. include_types: Only include events from runnables with matching types. @@ -1089,180 +1088,45 @@ class Runnable(Generic[Input, Output], ABC): Returns: An async stream of StreamEvents. """ # noqa: E501 - if version != "v1": + from langchain_core.tracers.event_stream import ( + _astream_events_implementation_v1, + _astream_events_implementation_v2, + ) + + if version == "v2": + event_stream = _astream_events_implementation_v2( + self, + input, + config=config, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + **kwargs, + ) + elif version == "v1": + # First implementation, built on top of astream_log API + # This implementation will be deprecated as of 0.2.0 + event_stream = _astream_events_implementation_v1( + self, + input, + config=config, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + **kwargs, + ) + else: raise NotImplementedError( - 'Only version "v1" of the schema is currently supported.' + 'Only versions "v1" and "v2" of the schema is currently supported.' ) - from langchain_core.runnables.utils import ( - _RootEventFilter, - ) - from langchain_core.tracers.log_stream import ( - LogStreamCallbackHandler, - RunLog, - _astream_log_implementation, - ) - - stream = LogStreamCallbackHandler( - auto_close=False, - include_names=include_names, - include_types=include_types, - include_tags=include_tags, - exclude_names=exclude_names, - exclude_types=exclude_types, - exclude_tags=exclude_tags, - _schema_format="streaming_events", - ) - - run_log = RunLog(state=None) # type: ignore[arg-type] - encountered_start_event = False - - _root_event_filter = _RootEventFilter( - include_names=include_names, - include_types=include_types, - include_tags=include_tags, - exclude_names=exclude_names, - exclude_types=exclude_types, - exclude_tags=exclude_tags, - ) - - config = ensure_config(config) - root_tags = config.get("tags", []) - root_metadata = config.get("metadata", {}) - root_name = config.get("run_name", self.get_name()) - - # Ignoring mypy complaint about too many different union combinations - # This arises because many of the argument types are unions - async for log in _astream_log_implementation( # type: ignore[misc] - self, - input, - config=config, - stream=stream, - diff=True, - with_streamed_output_list=True, - **kwargs, - ): - run_log = run_log + log - - if not encountered_start_event: - # Yield the start event for the root runnable. - encountered_start_event = True - state = run_log.state.copy() - - event = StreamEvent( - event=f"on_{state['type']}_start", - run_id=state["id"], - name=root_name, - tags=root_tags, - metadata=root_metadata, - data={ - "input": input, - }, - ) - - if _root_event_filter.include_event(event, state["type"]): - yield event - - paths = { - op["path"].split("/")[2] - for op in log.ops - if op["path"].startswith("/logs/") - } - # Elements in a set should be iterated in the same order - # as they were inserted in modern python versions. - for path in paths: - data: EventData = {} - log_entry: LogEntry = run_log.state["logs"][path] - if log_entry["end_time"] is None: - if log_entry["streamed_output"]: - event_type = "stream" - else: - event_type = "start" - else: - event_type = "end" - - if event_type == "start": - # Include the inputs with the start event if they are available. - # Usually they will NOT be available for components that operate - # on streams, since those components stream the input and - # don't know its final value until the end of the stream. - inputs = log_entry["inputs"] - if inputs is not None: - data["input"] = inputs - pass - - if event_type == "end": - inputs = log_entry["inputs"] - if inputs is not None: - data["input"] = inputs - - # None is a VALID output for an end event - data["output"] = log_entry["final_output"] - - if event_type == "stream": - num_chunks = len(log_entry["streamed_output"]) - if num_chunks != 1: - raise AssertionError( - f"Expected exactly one chunk of streamed output, " - f"got {num_chunks} instead. This is impossible. " - f"Encountered in: {log_entry['name']}" - ) - - data = {"chunk": log_entry["streamed_output"][0]} - # Clean up the stream, we don't need it anymore. - # And this avoids duplicates as well! - log_entry["streamed_output"] = [] - - yield StreamEvent( - event=f"on_{log_entry['type']}_{event_type}", - name=log_entry["name"], - run_id=log_entry["id"], - tags=log_entry["tags"], - metadata=log_entry["metadata"], - data=data, - ) - - # Finally, we take care of the streaming output from the root chain - # if there is any. - state = run_log.state - if state["streamed_output"]: - num_chunks = len(state["streamed_output"]) - if num_chunks != 1: - raise AssertionError( - f"Expected exactly one chunk of streamed output, " - f"got {num_chunks} instead. This is impossible. " - f"Encountered in: {state['name']}" - ) - - data = {"chunk": state["streamed_output"][0]} - # Clean up the stream, we don't need it anymore. - state["streamed_output"] = [] - - event = StreamEvent( - event=f"on_{state['type']}_stream", - run_id=state["id"], - tags=root_tags, - metadata=root_metadata, - name=root_name, - data=data, - ) - if _root_event_filter.include_event(event, state["type"]): - yield event - - state = run_log.state - - # Finally yield the end event for the root runnable. - event = StreamEvent( - event=f"on_{state['type']}_end", - name=root_name, - run_id=state["id"], - tags=root_tags, - metadata=root_metadata, - data={ - "output": state["final_output"], - }, - ) - if _root_event_filter.include_event(event, state["type"]): + async for event in event_stream: yield event def transform( @@ -1936,7 +1800,8 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" - from langchain_core.tracers.log_stream import LogStreamCallbackHandler + # Mixin that is used by both astream log and astream events implementation + from langchain_core.tracers._streaming import _StreamingCallbackHandler # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(input, 2) @@ -1964,16 +1829,18 @@ class Runnable(Generic[Input, Output], ABC): context = copy_context() context.run(var_child_runnable_config.set, child_config) iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] - if stream_log := next( + + if stream_handler := next( ( - h + cast(_StreamingCallbackHandler, h) for h in run_manager.handlers - if isinstance(h, LogStreamCallbackHandler) + # instance check OK here, it's a mixin + if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] ), None, ): # populates streamed_output in astream_log() output if needed - iterator = stream_log.tap_output_aiter(run_manager.run_id, iterator) + iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator) try: while True: if accepts_context(asyncio.create_task): diff --git a/libs/core/langchain_core/tracers/_streaming.py b/libs/core/langchain_core/tracers/_streaming.py new file mode 100644 index 00000000000..8c778daffa5 --- /dev/null +++ b/libs/core/langchain_core/tracers/_streaming.py @@ -0,0 +1,28 @@ +"""Internal tracers used for stream_log and astream events implementations.""" +import abc +from typing import AsyncIterator, TypeVar +from uuid import UUID + +T = TypeVar("T") + + +class _StreamingCallbackHandler(abc.ABC): + """For internal use. + + This is a common mixin that the callback handlers + for both astream events and astream log inherit from. + + The `tap_output_aiter` method is invoked in some contexts + to produce callbacks for intermediate results. + """ + + @abc.abstractmethod + def tap_output_aiter( + self, run_id: UUID, output: AsyncIterator[T] + ) -> AsyncIterator[T]: + """Used for internal astream_log and astream events implementations.""" + + +__all__ = [ + "_StreamingCallbackHandler", +] diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py new file mode 100644 index 00000000000..38fee911364 --- /dev/null +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -0,0 +1,790 @@ +"""Internal tracer to power the event stream API.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, + cast, +) +from uuid import UUID + +from typing_extensions import NotRequired, TypedDict + +from langchain_core.callbacks.base import AsyncCallbackHandler +from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk +from langchain_core.outputs import ( + ChatGenerationChunk, + GenerationChunk, + LLMResult, +) +from langchain_core.runnables.schema import EventData, StreamEvent +from langchain_core.runnables.utils import ( + Input, + Output, + _RootEventFilter, +) +from langchain_core.tracers._streaming import _StreamingCallbackHandler +from langchain_core.tracers.log_stream import LogEntry +from langchain_core.tracers.memory_stream import _MemoryStream + +if TYPE_CHECKING: + from langchain_core.documents import Document + from langchain_core.runnables import Runnable, RunnableConfig + +logger = logging.getLogger(__name__) + + +class RunInfo(TypedDict): + """Information about a run.""" + + name: str + tags: List[str] + metadata: Dict[str, Any] + run_type: str + inputs: NotRequired[Any] + + +def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> str: + """Assign a name to a run.""" + if name is not None: + return name + if "name" in serialized: + return serialized["name"] + elif "id" in serialized: + return serialized["id"][-1] + return "Unnamed" + + +T = TypeVar("T") + + +class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHandler): + """An implementation of an async callback handler for astream events.""" + + def __init__( + self, + *args: Any, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> None: + """Initialize the tracer.""" + super().__init__(*args, **kwargs) + # Map of run ID to run info. + self.run_map: Dict[UUID, RunInfo] = {} + + # Filter which events will be sent over the queue. + self.root_event_filter = _RootEventFilter( + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + loop = asyncio.get_event_loop() + memory_stream = _MemoryStream[StreamEvent](loop) + self.send_stream = memory_stream.get_send_stream() + self.receive_stream = memory_stream.get_receive_stream() + + async def _send(self, event: StreamEvent, event_type: str) -> None: + """Send an event to the stream.""" + if self.root_event_filter.include_event(event, event_type): + await self.send_stream.send(event) + + def __aiter__(self) -> AsyncIterator[Any]: + """Iterate over the receive stream.""" + return self.receive_stream.__aiter__() + + async def tap_output_aiter( + self, run_id: UUID, output: AsyncIterator[T] + ) -> AsyncIterator[T]: + """Tap the output aiter.""" + async for chunk in output: + run_info = self.run_map.get(run_id) + if run_info is None: + raise AssertionError(f"Run ID {run_id} not found in run map.") + await self._send( + { + "event": f"on_{run_info['run_type']}_stream", + "data": {"chunk": chunk}, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + yield chunk + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + name_ = _assign_name(name, serialized) + run_type = "chat_model" + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type, + "inputs": {"messages": messages}, + } + + await self._send( + { + "event": "on_chat_model_start", + "data": { + "input": {"messages": messages}, + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type, + ) + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + name_ = _assign_name(name, serialized) + run_type = "llm" + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type, + "inputs": {"prompts": prompts}, + } + + await self._send( + { + "event": "on_llm_start", + "data": { + "input": { + "prompts": prompts, + } + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type, + ) + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + run_info = self.run_map.get(run_id) + + chunk_: Union[GenerationChunk, BaseMessageChunk] + + if run_info is None: + raise AssertionError(f"Run ID {run_id} not found in run map.") + if run_info["run_type"] == "chat_model": + event = "on_chat_model_stream" + + if chunk is None: + chunk_ = AIMessageChunk(content=token) + else: + chunk_ = cast(ChatGenerationChunk, chunk).message + + elif run_info["run_type"] == "llm": + event = "on_llm_stream" + if chunk is None: + chunk_ = GenerationChunk(text=token) + else: + chunk_ = cast(GenerationChunk, chunk) + else: + raise ValueError(f"Unexpected run type: {run_info['run_type']}") + + await self._send( + { + "event": event, + "data": { + "chunk": chunk_, + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + + async def on_llm_end( + self, response: LLMResult, *, run_id: UUID, **kwargs: Any + ) -> None: + """End a trace for an LLM run.""" + run_info = self.run_map.pop(run_id) + inputs_ = run_info["inputs"] + + generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]] + output: Union[dict, BaseMessage] = {} + + if run_info["run_type"] == "chat_model": + generations = cast(List[List[ChatGenerationChunk]], response.generations) + for gen in generations: + if output != {}: + break + for chunk in gen: + output = chunk.message + break + + event = "on_chat_model_end" + elif run_info["run_type"] == "llm": + generations = cast(List[List[GenerationChunk]], response.generations) + output = { + "generations": [ + [ + { + "text": chunk.text, + "generation_info": chunk.generation_info, + "type": chunk.type, + } + for chunk in gen + ] + for gen in generations + ], + "llm_output": response.llm_output, + } + event = "on_llm_end" + else: + raise ValueError(f"Unexpected run type: {run_info['run_type']}") + + await self._send( + { + "event": event, + "data": {"output": output, "input": inputs_}, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for a chain run.""" + name_ = _assign_name(name, serialized) + run_type_ = run_type or "chain" + run_info: RunInfo = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type_, + } + + data: EventData = {} + + # Work-around Runnable core code not sending input in some + # cases. + if inputs != {"input": ""}: + data["input"] = inputs + run_info["inputs"] = inputs + + self.run_map[run_id] = run_info + + await self._send( + { + "event": f"on_{run_type_}_start", + "data": data, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type_, + ) + + async def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """End a trace for a chain run.""" + run_info = self.run_map.pop(run_id) + run_type = run_info["run_type"] + + event = f"on_{run_type}_end" + + inputs = inputs or run_info.get("inputs") or {} + + data: EventData = { + "output": outputs, + "input": inputs, + } + + await self._send( + { + "event": event, + "data": data, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_type, + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Start a trace for a tool run.""" + name_ = _assign_name(name, serialized) + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": "tool", + "inputs": inputs, + } + + await self._send( + { + "event": "on_tool_start", + "data": { + "input": inputs or {}, + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + "tool", + ) + + async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: + """End a trace for a tool run.""" + run_info = self.run_map.pop(run_id) + if "inputs" not in run_info: + raise AssertionError( + f"Run ID {run_id} is a tool call and is expected to have " + f"inputs associated with it." + ) + inputs = run_info["inputs"] + + await self._send( + { + "event": "on_tool_end", + "data": { + "output": output, + "input": inputs, + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + "tool", + ) + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Run when Retriever starts running.""" + name_ = _assign_name(name, serialized) + run_type = "retriever" + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type, + "inputs": {"query": query}, + } + + await self._send( + { + "event": "on_retriever_start", + "data": { + "input": { + "query": query, + } + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type, + ) + + async def on_retriever_end( + self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any + ) -> None: + """Run when Retriever ends running.""" + run_info = self.run_map.pop(run_id) + + await self._send( + { + "event": "on_retriever_end", + "data": { + "output": documents, + "input": run_info["inputs"], + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + + def __deepcopy__(self, memo: dict) -> _AstreamEventsCallbackHandler: + """Deepcopy the tracer.""" + return self + + def __copy__(self) -> _AstreamEventsCallbackHandler: + """Copy the tracer.""" + return self + + +async def _astream_events_implementation_v1( + runnable: Runnable[Input, Output], + input: Any, + config: Optional[RunnableConfig] = None, + *, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> AsyncIterator[StreamEvent]: + from langchain_core.runnables import ensure_config + from langchain_core.runnables.utils import _RootEventFilter + from langchain_core.tracers.log_stream import ( + LogStreamCallbackHandler, + RunLog, + _astream_log_implementation, + ) + + stream = LogStreamCallbackHandler( + auto_close=False, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + _schema_format="streaming_events", + ) + + run_log = RunLog(state=None) # type: ignore[arg-type] + encountered_start_event = False + + _root_event_filter = _RootEventFilter( + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + config = ensure_config(config) + root_tags = config.get("tags", []) + root_metadata = config.get("metadata", {}) + root_name = config.get("run_name", runnable.get_name()) + + # Ignoring mypy complaint about too many different union combinations + # This arises because many of the argument types are unions + async for log in _astream_log_implementation( # type: ignore[misc] + runnable, + input, + config=config, + stream=stream, + diff=True, + with_streamed_output_list=True, + **kwargs, + ): + run_log = run_log + log + + if not encountered_start_event: + # Yield the start event for the root runnable. + encountered_start_event = True + state = run_log.state.copy() + + event = StreamEvent( + event=f"on_{state['type']}_start", + run_id=state["id"], + name=root_name, + tags=root_tags, + metadata=root_metadata, + data={ + "input": input, + }, + ) + + if _root_event_filter.include_event(event, state["type"]): + yield event + + paths = { + op["path"].split("/")[2] + for op in log.ops + if op["path"].startswith("/logs/") + } + # Elements in a set should be iterated in the same order + # as they were inserted in modern python versions. + for path in paths: + data: EventData = {} + log_entry: LogEntry = run_log.state["logs"][path] + if log_entry["end_time"] is None: + if log_entry["streamed_output"]: + event_type = "stream" + else: + event_type = "start" + else: + event_type = "end" + + if event_type == "start": + # Include the inputs with the start event if they are available. + # Usually they will NOT be available for components that operate + # on streams, since those components stream the input and + # don't know its final value until the end of the stream. + inputs = log_entry["inputs"] + if inputs is not None: + data["input"] = inputs + pass + + if event_type == "end": + inputs = log_entry["inputs"] + if inputs is not None: + data["input"] = inputs + + # None is a VALID output for an end event + data["output"] = log_entry["final_output"] + + if event_type == "stream": + num_chunks = len(log_entry["streamed_output"]) + if num_chunks != 1: + raise AssertionError( + f"Expected exactly one chunk of streamed output, " + f"got {num_chunks} instead. This is impossible. " + f"Encountered in: {log_entry['name']}" + ) + + data = {"chunk": log_entry["streamed_output"][0]} + # Clean up the stream, we don't need it anymore. + # And this avoids duplicates as well! + log_entry["streamed_output"] = [] + + yield StreamEvent( + event=f"on_{log_entry['type']}_{event_type}", + name=log_entry["name"], + run_id=log_entry["id"], + tags=log_entry["tags"], + metadata=log_entry["metadata"], + data=data, + ) + + # Finally, we take care of the streaming output from the root chain + # if there is any. + state = run_log.state + if state["streamed_output"]: + num_chunks = len(state["streamed_output"]) + if num_chunks != 1: + raise AssertionError( + f"Expected exactly one chunk of streamed output, " + f"got {num_chunks} instead. This is impossible. " + f"Encountered in: {state['name']}" + ) + + data = {"chunk": state["streamed_output"][0]} + # Clean up the stream, we don't need it anymore. + state["streamed_output"] = [] + + event = StreamEvent( + event=f"on_{state['type']}_stream", + run_id=state["id"], + tags=root_tags, + metadata=root_metadata, + name=root_name, + data=data, + ) + if _root_event_filter.include_event(event, state["type"]): + yield event + + state = run_log.state + + # Finally yield the end event for the root runnable. + event = StreamEvent( + event=f"on_{state['type']}_end", + name=root_name, + run_id=state["id"], + tags=root_tags, + metadata=root_metadata, + data={ + "output": state["final_output"], + }, + ) + if _root_event_filter.include_event(event, state["type"]): + yield event + + +async def _astream_events_implementation_v2( + runnable: Runnable[Input, Output], + input: Any, + config: Optional[RunnableConfig] = None, + *, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> AsyncIterator[StreamEvent]: + """Implementation of the astream events API for V2 runnables.""" + from langchain_core.callbacks.base import BaseCallbackManager + from langchain_core.runnables import ensure_config + + event_streamer = _AstreamEventsCallbackHandler( + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + # Assign the stream handler to the config + config = ensure_config(config) + callbacks = config.get("callbacks") + if callbacks is None: + config["callbacks"] = [event_streamer] + elif isinstance(callbacks, list): + config["callbacks"] = callbacks + [event_streamer] + elif isinstance(callbacks, BaseCallbackManager): + callbacks = callbacks.copy() + callbacks.add_handler(event_streamer, inherit=True) + config["callbacks"] = callbacks + else: + raise ValueError( + f"Unexpected type for callbacks: {callbacks}." + "Expected None, list or AsyncCallbackManager." + ) + + # Call the runnable in streaming mode, + # add each chunk to the output stream + async def consume_astream() -> None: + try: + async for _ in runnable.astream(input, config, **kwargs): + # All the content will be picked up + pass + finally: + await event_streamer.send_stream.aclose() + + # Start the runnable in a task, so we can start consuming output + task = asyncio.create_task(consume_astream()) + + first_event_sent = False + first_event_run_id = None + + try: + async for event in event_streamer: + if not first_event_sent: + first_event_sent = True + # This is a work-around an issue where the inputs into the + # chain are not available until the entire input is consumed. + # As a temporary solution, we'll modify the input to be the input + # that was passed into the chain. + event["data"]["input"] = input + first_event_run_id = event["run_id"] + yield event + continue + + if event["run_id"] == first_event_run_id and event["event"].endswith( + "_end" + ): + # If it's the end event corresponding to the root runnable + # we dont include the input in the event since it's guaranteed + # to be included in the first event. + if "input" in event["data"]: + del event["data"]["input"] + + yield event + finally: + # Wait for the runnable to finish, if not cancelled (eg. by break) + try: + await task + except asyncio.CancelledError: + pass diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 4c02f933e42..47c789b6783 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -26,6 +26,7 @@ from langchain_core.load.load import load from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import Input, Output +from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.schemas import Run @@ -157,7 +158,7 @@ class RunLog(RunLogPatch): T = TypeVar("T") -class LogStreamCallbackHandler(BaseTracer): +class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): """Tracer that streams run logs to a stream.""" def __init__( diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index f233ba2d61d..011d482f331 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1198,7 +1198,7 @@ files = [ [[package]] name = "langchain-text-splitters" -version = "0.0.1" +version = "0.0.2" description = "LangChain text splitting utilities" optional = false python-versions = ">=3.8.1,<4.0" @@ -2142,7 +2142,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py similarity index 100% rename from libs/core/tests/unit_tests/runnables/test_runnable_events.py rename to libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py new file mode 100644 index 00000000000..47c69fe5f90 --- /dev/null +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -0,0 +1,1694 @@ +"""Module that contains tests for runnable.astream_events API.""" +import sys +from itertools import cycle +from typing import Any, AsyncIterator, Dict, List, Sequence, cast + +import pytest + +from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.documents import Document +from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.prompt_values import ChatPromptValue +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.retrievers import BaseRetriever +from langchain_core.runnables import ( + ConfigurableField, + Runnable, + RunnableConfig, + RunnableLambda, +) +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.runnables.schema import StreamEvent +from langchain_core.tools import tool +from tests.unit_tests.stubs import AnyStr + + +def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: + """Removes the run ids from events.""" + return cast(List[StreamEvent], [{**event, "run_id": ""} for event in events]) + + +async def _as_async_iterator(iterable: List) -> AsyncIterator: + """Converts an iterable into an async iterator.""" + for item in iterable: + yield item + + +async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEvent]: + """Collect the events and remove the run ids.""" + materialized_events = [event async for event in events] + events_ = _with_nulled_run_id(materialized_events) + for event in events_: + event["tags"] = sorted(event["tags"]) + return events_ + + +async def test_event_stream_with_simple_function_tool() -> None: + """Test the event stream with a function and tool""" + + def foo(x: int) -> dict: + """Foo""" + return {"x": 5} + + @tool + def get_docs(x: int) -> List[Document]: + """Hello Doc""" + return [Document(page_content="hello")] + + chain = RunnableLambda(foo) | get_docs + events = await _collect_events(chain.astream_events({}, version="v2")) + assert events == [ + { + "event": "on_chain_start", + "run_id": "", + "name": "RunnableSequence", + "tags": [], + "metadata": {}, + "data": {"input": {}}, + }, + { + "event": "on_chain_start", + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + "metadata": {}, + "data": {}, + }, + { + "event": "on_chain_stream", + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + "metadata": {}, + "data": {"chunk": {"x": 5}}, + }, + { + "event": "on_chain_end", + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + "metadata": {}, + "data": {"input": {}, "output": {"x": 5}}, + }, + { + "event": "on_tool_start", + "name": "get_docs", + "run_id": "", + "tags": ["seq:step:2"], + "metadata": {}, + "data": {"input": {"x": 5}}, + }, + { + "event": "on_tool_end", + "name": "get_docs", + "run_id": "", + "tags": ["seq:step:2"], + "metadata": {}, + "data": {"input": {"x": 5}, "output": [Document(page_content="hello")]}, + }, + { + "event": "on_chain_stream", + "run_id": "", + "tags": [], + "metadata": {}, + "name": "RunnableSequence", + "data": {"chunk": [Document(page_content="hello")]}, + }, + { + "event": "on_chain_end", + "name": "RunnableSequence", + "run_id": "", + "tags": [], + "metadata": {}, + "data": {"output": [Document(page_content="hello")]}, + }, + ] + + +async def test_event_stream_with_single_lambda() -> None: + """Test the event stream with a tool.""" + + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + chain = RunnableLambda(func=reverse) + + events = await _collect_events(chain.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_triple_lambda() -> None: + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + r = RunnableLambda(func=reverse) + + chain = ( + r.with_config({"run_name": "1"}) + | r.with_config({"run_name": "2"}) + | r.with_config({"run_name": "3"}) + ) + events = await _collect_events(chain.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "2", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"input": "hello", "output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "hello"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "2", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["seq:step:3"], + }, + { + "data": {"input": "olleh", "output": "hello"}, + "event": "on_chain_end", + "metadata": {}, + "name": "2", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["seq:step:3"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": "hello", "output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["seq:step:3"], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_triple_lambda_test_filtering() -> None: + """Test filtering based on tags / names""" + + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + r = RunnableLambda(func=reverse) + + chain = ( + r.with_config({"run_name": "1"}) + | r.with_config({"run_name": "2", "tags": ["my_tag"]}) + | r.with_config({"run_name": "3", "tags": ["my_tag"]}) + ) + events = await _collect_events( + chain.astream_events("hello", include_names=["1"], version="v2") + ) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + ] + + events = await _collect_events( + chain.astream_events( + "hello", include_tags=["my_tag"], exclude_names=["2"], version="v2" + ) + ) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["my_tag", "seq:step:3"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["my_tag", "seq:step:3"], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["my_tag", "seq:step:3"], + }, + ] + + +async def test_event_stream_with_lambdas_from_lambda() -> None: + as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config( + {"run_name": "my_lambda"} + ) + events = await _collect_events( + as_lambdas.astream_events({"question": "hello"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"question": "hello"}}, + "event": "on_chain_start", + "metadata": {}, + "name": "my_lambda", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": {"answer": "goodbye"}}, + "event": "on_chain_stream", + "metadata": {}, + "name": "my_lambda", + "run_id": "", + "tags": [], + }, + { + "data": {"output": {"answer": "goodbye"}}, + "event": "on_chain_end", + "metadata": {}, + "name": "my_lambda", + "run_id": "", + "tags": [], + }, + ] + + +async def test_astream_events_from_model() -> None: + """Test the output of a model.""" + infinite_cycle = cycle([AIMessage(content="hello world!")]) + # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces + model = ( + GenericFakeChatModel(messages=infinite_cycle) + .with_config( + { + "metadata": {"a": "b"}, + "tags": ["my_model"], + "run_name": "my_model", + } + ) + .bind(stop="") + ) + events = await _collect_events(model.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "output": AIMessageChunk(content="hello world!", id=AnyStr()), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + ] + + +async def test_astream_with_model_in_chain() -> None: + """Scenarios with model when it is not the only runnable in the chain.""" + infinite_cycle = cycle([AIMessage(content="hello world!")]) + # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces + model = ( + GenericFakeChatModel(messages=infinite_cycle) + .with_config( + { + "metadata": {"a": "b"}, + "tags": ["my_model"], + "run_name": "my_model", + } + ) + .bind(stop="") + ) + + @RunnableLambda + def i_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return model.invoke(input) + else: + return model.invoke(input, config) + + events = await _collect_events(i_dont_stream.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": AIMessage(content="hello world!", id=AnyStr()), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_stream", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_end", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + ] + + @RunnableLambda + async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return await model.ainvoke(input) + else: + return await model.ainvoke(input, config) + + events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": AIMessage(content="hello world!", id=AnyStr()), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_stream", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_end", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_simple_chain() -> None: + """Test as event stream.""" + template = ChatPromptTemplate.from_messages( + [("system", "You are Cat Agent 007"), ("human", "{question}")] + ).with_config({"run_name": "my_template", "tags": ["my_template"]}) + + infinite_cycle = cycle( + [ + AIMessage(content="hello world!", id="ai1"), + AIMessage(content="goodbye world!", id="ai2"), + ] + ) + # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces + model = ( + GenericFakeChatModel(messages=infinite_cycle) + .with_config( + { + "metadata": {"a": "b"}, + "tags": ["my_model"], + "run_name": "my_model", + } + ) + .bind(stop="") + ) + + chain = (template | model).with_config( + { + "metadata": {"foo": "bar"}, + "tags": ["my_chain"], + "run_name": "my_chain", + } + ) + + events = await _collect_events( + chain.astream_events({"question": "hello"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"question": "hello"}}, + "event": "on_chain_start", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": {"input": {"question": "hello"}}, + "event": "on_prompt_start", + "metadata": {"foo": "bar"}, + "name": "my_template", + "run_id": "", + "tags": ["my_chain", "my_template", "seq:step:1"], + }, + { + "data": { + "input": {"question": "hello"}, + "output": ChatPromptValue( + messages=[ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ), + }, + "event": "on_prompt_end", + "metadata": {"foo": "bar"}, + "name": "my_template", + "run_id": "", + "tags": ["my_chain", "my_template", "seq:step:1"], + }, + { + "data": { + "input": { + "messages": [ + [ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ] + } + }, + "event": "on_chat_model_start", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id="ai1")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id="ai1")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id="ai1")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id="ai1")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id="ai1")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id="ai1")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": { + "input": { + "messages": [ + [ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ] + }, + "output": AIMessageChunk(content="hello world!", id="ai1"), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"output": AIMessageChunk(content="hello world!", id="ai1")}, + "event": "on_chain_end", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + ] + + +async def test_event_streaming_with_tools() -> None: + """Test streaming events with different tool definitions.""" + + @tool + def parameterless() -> str: + """A tool that does nothing.""" + return "hello" + + @tool + def with_callbacks(callbacks: Callbacks) -> str: + """A tool that does nothing.""" + return "world" + + @tool + def with_parameters(x: int, y: str) -> dict: + """A tool that does nothing.""" + return {"x": x, "y": y} + + @tool + def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: + """A tool that does nothing.""" + return {"x": x, "y": y} + + # type ignores below because the tools don't appear to be runnables to type checkers + # we can remove as soon as that's fixed + events = await _collect_events(parameterless.astream_events({}, version="v2")) # type: ignore + assert events == [ + { + "data": {"input": {}}, + "event": "on_tool_start", + "metadata": {}, + "name": "parameterless", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "hello"}, + "event": "on_tool_end", + "metadata": {}, + "name": "parameterless", + "run_id": "", + "tags": [], + }, + ] + events = await _collect_events(with_callbacks.astream_events({}, version="v2")) # type: ignore + assert events == [ + { + "data": {"input": {}}, + "event": "on_tool_start", + "metadata": {}, + "name": "with_callbacks", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "world"}, + "event": "on_tool_end", + "metadata": {}, + "name": "with_callbacks", + "run_id": "", + "tags": [], + }, + ] + events = await _collect_events( + with_parameters.astream_events({"x": 1, "y": "2"}, version="v2") # type: ignore + ) + assert events == [ + { + "data": {"input": {"x": 1, "y": "2"}}, + "event": "on_tool_start", + "metadata": {}, + "name": "with_parameters", + "run_id": "", + "tags": [], + }, + { + "data": {"output": {"x": 1, "y": "2"}}, + "event": "on_tool_end", + "metadata": {}, + "name": "with_parameters", + "run_id": "", + "tags": [], + }, + ] + + events = await _collect_events( + with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v2") # type: ignore + ) + assert events == [ + { + "data": {"input": {"x": 1, "y": "2"}}, + "event": "on_tool_start", + "metadata": {}, + "name": "with_parameters_and_callbacks", + "run_id": "", + "tags": [], + }, + { + "data": {"output": {"x": 1, "y": "2"}}, + "event": "on_tool_end", + "metadata": {}, + "name": "with_parameters_and_callbacks", + "run_id": "", + "tags": [], + }, + ] + + +class HardCodedRetriever(BaseRetriever): + documents: List[Document] + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + return self.documents + + +async def test_event_stream_with_retriever() -> None: + """Test the event stream with a retriever.""" + retriever = HardCodedRetriever( + documents=[ + Document( + page_content="hello world!", + metadata={"foo": "bar"}, + ), + Document( + page_content="goodbye world!", + metadata={"food": "spare"}, + ), + ] + ) + events = await _collect_events( + retriever.astream_events({"query": "hello"}, version="v2") + ) + assert events == [ + { + "data": { + "input": {"query": "hello"}, + }, + "event": "on_retriever_start", + "metadata": {}, + "name": "HardCodedRetriever", + "run_id": "", + "tags": [], + }, + { + "data": { + "output": [ + Document(page_content="hello world!", metadata={"foo": "bar"}), + Document(page_content="goodbye world!", metadata={"food": "spare"}), + ] + }, + "event": "on_retriever_end", + "metadata": {}, + "name": "HardCodedRetriever", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_retriever_and_formatter() -> None: + """Test the event stream with a retriever.""" + retriever = HardCodedRetriever( + documents=[ + Document( + page_content="hello world!", + metadata={"foo": "bar"}, + ), + Document( + page_content="goodbye world!", + metadata={"food": "spare"}, + ), + ] + ) + + def format_docs(docs: List[Document]) -> str: + """Format the docs.""" + return ", ".join([doc.page_content for doc in docs]) + + chain = retriever | format_docs + events = await _collect_events(chain.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"query": "hello"}}, + "event": "on_retriever_start", + "metadata": {}, + "name": "HardCodedRetriever", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": { + "input": {"query": "hello"}, + "output": [ + Document(page_content="hello world!", metadata={"foo": "bar"}), + Document(page_content="goodbye world!", metadata={"food": "spare"}), + ], + }, + "event": "on_retriever_end", + "metadata": {}, + "name": "HardCodedRetriever", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "format_docs", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "hello world!, goodbye world!"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "format_docs", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "hello world!, goodbye world!"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": { + "input": [ + Document(page_content="hello world!", metadata={"foo": "bar"}), + Document(page_content="goodbye world!", metadata={"food": "spare"}), + ], + "output": "hello world!, goodbye world!", + }, + "event": "on_chain_end", + "metadata": {}, + "name": "format_docs", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"output": "hello world!, goodbye world!"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_on_chain_with_tool() -> None: + """Test the event stream with a tool.""" + + @tool + def concat(a: str, b: str) -> str: + """A tool that does nothing.""" + return a + b + + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + # For whatever reason type annotations fail here because reverse + # does not appear to be a runnable + chain = concat | reverse # type: ignore + + events = await _collect_events( + chain.astream_events({"a": "hello", "b": "world"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"a": "hello", "b": "world"}}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"a": "hello", "b": "world"}}, + "event": "on_tool_start", + "metadata": {}, + "name": "concat", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"input": {"a": "hello", "b": "world"}, "output": "helloworld"}, + "event": "on_tool_end", + "metadata": {}, + "name": "concat", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "dlrowolleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "dlrowolleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": "helloworld", "output": "dlrowolleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"output": "dlrowolleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +@pytest.mark.xfail(reason="Fix order of callback invocations in RunnableSequence") +async def test_chain_ordering() -> None: + """Test the event stream with a tool.""" + + def foo(a: str) -> str: + return a + + def bar(a: str) -> str: + return a + + chain = RunnableLambda(foo) | RunnableLambda(bar) + iterable = chain.astream_events("q", version="v2") + + events = [] + + for _ in range(10): + try: + next_chunk = await iterable.__anext__() + events.append(next_chunk) + except Exception: + break + + events = _with_nulled_run_id(events) + for event in events: + event["tags"] = sorted(event["tags"]) + + assert events == [ + { + "data": {"input": "q"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "q"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"input": "q", "output": "q"}, + "event": "on_chain_end", + "metadata": {}, + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "bar", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "q"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "bar", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "q"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": "q", "output": "q"}, + "event": "on_chain_end", + "metadata": {}, + "name": "bar", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"output": "q"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_retry() -> None: + """Test the event stream with a tool.""" + + def success(inputs: str) -> str: + return "success" + + def fail(inputs: str) -> None: + """Simple func.""" + raise Exception("fail") + + chain = RunnableLambda(success) | RunnableLambda(fail).with_retry( + stop_after_attempt=1, + ) + iterable = chain.astream_events("q", version="v2") + + events = [] + + for _ in range(10): + try: + next_chunk = await iterable.__anext__() + events.append(next_chunk) + except Exception: + break + + events = _with_nulled_run_id(events) + for event in events: + event["tags"] = sorted(event["tags"]) + + assert events == [ + { + "data": {"input": "q"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "success", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "success"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "success", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "fail", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"input": "q", "output": "success"}, + "event": "on_chain_end", + "metadata": {}, + "name": "success", + "run_id": "", + "tags": ["seq:step:1"], + }, + ] + + +async def test_with_llm() -> None: + """Test with regular llm.""" + prompt = ChatPromptTemplate.from_messages( + [("system", "You are Cat Agent 007"), ("human", "{question}")] + ).with_config({"run_name": "my_template", "tags": ["my_template"]}) + llm = FakeStreamingListLLM(responses=["abc"]) + + chain = prompt | llm + events = await _collect_events( + chain.astream_events({"question": "hello"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"question": "hello"}}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"question": "hello"}}, + "event": "on_prompt_start", + "metadata": {}, + "name": "my_template", + "run_id": "", + "tags": ["my_template", "seq:step:1"], + }, + { + "data": { + "input": {"question": "hello"}, + "output": ChatPromptValue( + messages=[ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ), + }, + "event": "on_prompt_end", + "metadata": {}, + "name": "my_template", + "run_id": "", + "tags": ["my_template", "seq:step:1"], + }, + { + "data": { + "input": {"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]} + }, + "event": "on_llm_start", + "metadata": {}, + "name": "FakeStreamingListLLM", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": { + "input": { + "prompts": ["System: You are Cat Agent 007\n" "Human: hello"] + }, + "output": { + "generations": [ + [{"generation_info": None, "text": "abc", "type": "Generation"}] + ], + "llm_output": None, + }, + }, + "event": "on_llm_end", + "metadata": {}, + "name": "FakeStreamingListLLM", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "a"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "b"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "c"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "abc"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_runnable_each() -> None: + """Test runnable each astream_events.""" + + async def add_one(x: int) -> int: + return x + 1 + + add_one_map = RunnableLambda(add_one).map() # type: ignore + assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4] + + with pytest.raises(NotImplementedError): + async for _ in add_one_map.astream_events([1, 2, 3], version="v2"): + pass + + +async def test_events_astream_config() -> None: + """Test that astream events support accepting config""" + infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")]) + good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")]) + model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields( + messages=ConfigurableField( + id="messages", + name="Messages", + description="Messages return by the LLM", + ) + ) + + model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}}) + assert model_02.invoke("hello") == AIMessage(content="Goodbye world", id="ai2") + + events = await _collect_events(model_02.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chat_model_start", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content="Goodbye", id="ai2")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id="ai2")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content="world", id="ai2")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": { + "output": AIMessageChunk(content="Goodbye world", id="ai2"), + }, + "event": "on_chat_model_end", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + ] + + +async def test_runnable_with_message_history() -> None: + class InMemoryHistory(BaseChatMessageHistory, BaseModel): + """In memory implementation of chat message history.""" + + # Attention: for the tests use an Any type to work-around a pydantic issue + # where it re-instantiates a list, so mutating the list doesn't end up mutating + # the content in the store! + + # Using Any type here rather than List[BaseMessage] due to pydantic issue! + messages: Any + + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store.""" + self.messages.append(message) + + def clear(self) -> None: + self.messages = [] + + # Here we use a global variable to store the chat message history. + # This will make it easier to inspect it to see the underlying results. + store: Dict = {} + + def get_by_session_id(session_id: str) -> BaseChatMessageHistory: + """Get a chat message history""" + if session_id not in store: + store[session_id] = [] + return InMemoryHistory(messages=store[session_id]) + + infinite_cycle = cycle( + [AIMessage(content="hello", id="ai3"), AIMessage(content="world", id="ai4")] + ) + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are a cat"), + MessagesPlaceholder(variable_name="history"), + ("human", "{question}"), + ] + ) + model = GenericFakeChatModel(messages=infinite_cycle) + + chain: Runnable = prompt | model + with_message_history = RunnableWithMessageHistory( + chain, + get_session_history=get_by_session_id, + input_messages_key="question", + history_messages_key="history", + ) + await with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).ainvoke({"question": "hello"}) + + assert store == { + "session-123": [ + HumanMessage(content="hello"), + AIMessage(content="hello", id="ai3"), + ] + } + + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke({"question": "meow"}) + assert store == { + "session-123": [ + HumanMessage(content="hello"), + AIMessage(content="hello", id="ai3"), + HumanMessage(content="meow"), + AIMessage(content="world", id="ai4"), + ] + } + + +EXPECTED_EVENTS = [ + { + "data": {"input": 1}, + "event": "on_chain_start", + "metadata": {}, + "name": "add_one_proxy", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "add_one", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": 2}, + "event": "on_chain_stream", + "metadata": {}, + "name": "add_one", + "run_id": "", + "tags": [], + }, + { + "data": {"input": 1, "output": 2}, + "event": "on_chain_end", + "metadata": {}, + "name": "add_one", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": 2}, + "event": "on_chain_stream", + "metadata": {}, + "name": "add_one_proxy", + "run_id": "", + "tags": [], + }, + { + "data": {"output": 2}, + "event": "on_chain_end", + "metadata": {}, + "name": "add_one_proxy", + "run_id": "", + "tags": [], + }, +] + + +@pytest.mark.xfail( + reason="This test is failing due to missing functionality." + "Need to implement logic in _transform_stream_with_config that mimics the async " + "variant that uses tap_output_iter" +) +async def test_sync_in_async_stream_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + def add_one_(x: int) -> int: + return x + 1 + + add_one = RunnableLambda(add_one_) + + async def add_one_proxy_(x: int, config: RunnableConfig) -> int: + streaming = add_one.stream(x, config) + results = [result for result in streaming] + return results[0] + + add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore + + events = await _collect_events(add_one_proxy.astream_events(1, version="v2")) + assert events == EXPECTED_EVENTS + + +async def test_async_in_async_stream_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + async def add_one(x: int) -> int: + return x + 1 + + add_one_ = RunnableLambda(add_one) # type: ignore + + async def add_one_proxy(x: int, config: RunnableConfig) -> int: + # Use sync streaming + streaming = add_one_.astream(x, config) + results = [result async for result in streaming] + return results[0] + + add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore + + events = await _collect_events(add_one_proxy_.astream_events(1, version="v2")) + assert events == EXPECTED_EVENTS + + +@pytest.mark.xfail( + reason="This test is failing due to missing functionality." + "Need to implement logic in _transform_stream_with_config that mimics the async " + "variant that uses tap_output_iter" +) +async def test_sync_in_sync_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + def add_one(x: int) -> int: + return x + 1 + + add_one_ = RunnableLambda(add_one) + + def add_one_proxy(x: int, config: RunnableConfig) -> int: + # Use sync streaming + streaming = add_one_.stream(x, config) + results = [result for result in streaming] + return results[0] + + add_one_proxy_ = RunnableLambda(add_one_proxy) + + events = await _collect_events(add_one_proxy_.astream_events(1, version="v2")) + assert events == EXPECTED_EVENTS