diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 6abd5880553..f6967f51c09 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -52,7 +52,7 @@ from langchain_core.outputs import ( from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.runnables.config import ensure_config, run_in_executor -from langchain_core.tracers.log_stream import LogStreamCallbackHandler +from langchain_core.tracers._streaming import _StreamingCallbackHandler if TYPE_CHECKING: from langchain_core.pydantic_v1 import BaseModel @@ -608,7 +608,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ( True for h in run_manager.handlers - if isinstance(h, LogStreamCallbackHandler) + if isinstance(h, _StreamingCallbackHandler) ), False, ) @@ -691,7 +691,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ( True for h in run_manager.handlers - if isinstance(h, LogStreamCallbackHandler) + if isinstance(h, _StreamingCallbackHandler) ), False, ) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index def410ba693..c2dfa420183 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -58,7 +58,7 @@ from langchain_core.runnables.config import ( var_child_runnable_config, ) from langchain_core.runnables.graph import Graph -from langchain_core.runnables.schema import EventData, StreamEvent +from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import ( AddableDict, AnyConfigurableField, @@ -90,7 +90,6 @@ if TYPE_CHECKING: RunnableWithFallbacks as RunnableWithFallbacksT, ) from langchain_core.tracers.log_stream import ( - LogEntry, RunLog, RunLogPatch, ) @@ -927,7 +926,7 @@ class Runnable(Generic[Input, Output], ABC): input: Any, config: Optional[RunnableConfig] = None, *, - version: Literal["v1"], + version: Literal["v1", "v2"], include_names: Optional[Sequence[str]] = None, include_types: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None, @@ -962,6 +961,8 @@ class Runnable(Generic[Input, Output], ABC): chains. Metadata fields have been omitted from the table for brevity. Chain definitions have been included after the table. + **ATTENTION** This reference table is for the V2 version of the schema. 
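+              (In the earlier `v1` schema, `on_chat_model_end` carried a raw `{"generations": [...], "llm_output": None, ...}` payload, and dedicated `on_tool_stream` and `on_retriever_chunk` events existed; both were dropped in `v2`, as the removed table rows below show.)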
+ +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | event | name | chunk | input | output | +======================+==================+=================================+===============================================+=================================================+ @@ -969,7 +970,7 @@ class Runnable(Generic[Input, Output], ABC): +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_chat_model_stream | [model name] | AIMessageChunk(content="hello") | | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | {"generations": [...], "llm_output": None, ...} | + | on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | AIMessageChunk(content="hello world") | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_llm_start | [model name] | | {'input': 'hello'} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ @@ -985,15 +986,11 @@ class Runnable(Generic[Input, Output], ABC): +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_tool_start | some_tool | | {"x": 1, "y": "2"} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_tool_stream | some_tool | {"x": 1, "y": "2"} | | | - +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_tool_end | some_tool | | | {"x": 1, "y": "2"} | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_retriever_start | [retriever name] | | {"query": "hello"} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_retriever_chunk | [retriever name] | {documents: [...]} | | | - +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ - | on_retriever_end | [retriever name] | | {"query": "hello"} | {documents: [...]} | + | on_retriever_end | [retriever name] | | {"query": "hello"} | [Document(...), ..] 
| +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ | on_prompt_start | [template_name] | | {"question": "hello"} | | +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+ @@ -1042,7 +1039,7 @@ class Runnable(Generic[Input, Output], ABC): chain = RunnableLambda(func=reverse) events = [ - event async for event in chain.astream_events("hello", version="v1") + event async for event in chain.astream_events("hello", version="v2") ] # will produce the following events (run_id has been omitted for brevity): @@ -1073,8 +1070,10 @@ class Runnable(Generic[Input, Output], ABC): Args: input: The input to the runnable. config: The config to use for the runnable. - version: The version of the schema to use. - Currently only version 1 is available. + version: The version of the schema to use, either `v2` or `v1`. + Users should use `v2`. + `v1` is for backwards compatibility and will be deprecated + in 0.4.0. No default will be assigned until the API is stabilized. include_names: Only include events from runnables with matching names. include_types: Only include events from runnables with matching types. @@ -1089,180 +1088,45 @@ class Runnable(Generic[Input, Output], ABC): Returns: An async stream of StreamEvents. """ # noqa: E501 - if version != "v1": + from langchain_core.tracers.event_stream import ( + _astream_events_implementation_v1, + _astream_events_implementation_v2, + ) + + if version == "v2": + event_stream = _astream_events_implementation_v2( + self, + input, + config=config, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + **kwargs, + ) + elif version == "v1": + # First implementation, built on top of the astream_log API + # This implementation will be deprecated as of 0.2.0 + event_stream = _astream_events_implementation_v1( + self, + input, + config=config, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + **kwargs, + ) + else: raise NotImplementedError( - 'Only version "v1" of the schema is currently supported.' + 'Only versions "v1" and "v2" of the schema are currently supported.' 
) - from langchain_core.runnables.utils import ( - _RootEventFilter, - ) - from langchain_core.tracers.log_stream import ( - LogStreamCallbackHandler, - RunLog, - _astream_log_implementation, - ) - - stream = LogStreamCallbackHandler( - auto_close=False, - include_names=include_names, - include_types=include_types, - include_tags=include_tags, - exclude_names=exclude_names, - exclude_types=exclude_types, - exclude_tags=exclude_tags, - _schema_format="streaming_events", - ) - - run_log = RunLog(state=None) # type: ignore[arg-type] - encountered_start_event = False - - _root_event_filter = _RootEventFilter( - include_names=include_names, - include_types=include_types, - include_tags=include_tags, - exclude_names=exclude_names, - exclude_types=exclude_types, - exclude_tags=exclude_tags, - ) - - config = ensure_config(config) - root_tags = config.get("tags", []) - root_metadata = config.get("metadata", {}) - root_name = config.get("run_name", self.get_name()) - - # Ignoring mypy complaint about too many different union combinations - # This arises because many of the argument types are unions - async for log in _astream_log_implementation( # type: ignore[misc] - self, - input, - config=config, - stream=stream, - diff=True, - with_streamed_output_list=True, - **kwargs, - ): - run_log = run_log + log - - if not encountered_start_event: - # Yield the start event for the root runnable. - encountered_start_event = True - state = run_log.state.copy() - - event = StreamEvent( - event=f"on_{state['type']}_start", - run_id=state["id"], - name=root_name, - tags=root_tags, - metadata=root_metadata, - data={ - "input": input, - }, - ) - - if _root_event_filter.include_event(event, state["type"]): - yield event - - paths = { - op["path"].split("/")[2] - for op in log.ops - if op["path"].startswith("/logs/") - } - # Elements in a set should be iterated in the same order - # as they were inserted in modern python versions. - for path in paths: - data: EventData = {} - log_entry: LogEntry = run_log.state["logs"][path] - if log_entry["end_time"] is None: - if log_entry["streamed_output"]: - event_type = "stream" - else: - event_type = "start" - else: - event_type = "end" - - if event_type == "start": - # Include the inputs with the start event if they are available. - # Usually they will NOT be available for components that operate - # on streams, since those components stream the input and - # don't know its final value until the end of the stream. - inputs = log_entry["inputs"] - if inputs is not None: - data["input"] = inputs - pass - - if event_type == "end": - inputs = log_entry["inputs"] - if inputs is not None: - data["input"] = inputs - - # None is a VALID output for an end event - data["output"] = log_entry["final_output"] - - if event_type == "stream": - num_chunks = len(log_entry["streamed_output"]) - if num_chunks != 1: - raise AssertionError( - f"Expected exactly one chunk of streamed output, " - f"got {num_chunks} instead. This is impossible. " - f"Encountered in: {log_entry['name']}" - ) - - data = {"chunk": log_entry["streamed_output"][0]} - # Clean up the stream, we don't need it anymore. - # And this avoids duplicates as well! - log_entry["streamed_output"] = [] - - yield StreamEvent( - event=f"on_{log_entry['type']}_{event_type}", - name=log_entry["name"], - run_id=log_entry["id"], - tags=log_entry["tags"], - metadata=log_entry["metadata"], - data=data, - ) - - # Finally, we take care of the streaming output from the root chain - # if there is any. 
- state = run_log.state - if state["streamed_output"]: - num_chunks = len(state["streamed_output"]) - if num_chunks != 1: - raise AssertionError( - f"Expected exactly one chunk of streamed output, " - f"got {num_chunks} instead. This is impossible. " - f"Encountered in: {state['name']}" - ) - - data = {"chunk": state["streamed_output"][0]} - # Clean up the stream, we don't need it anymore. - state["streamed_output"] = [] - - event = StreamEvent( - event=f"on_{state['type']}_stream", - run_id=state["id"], - tags=root_tags, - metadata=root_metadata, - name=root_name, - data=data, - ) - if _root_event_filter.include_event(event, state["type"]): - yield event - - state = run_log.state - - # Finally yield the end event for the root runnable. - event = StreamEvent( - event=f"on_{state['type']}_end", - name=root_name, - run_id=state["id"], - tags=root_tags, - metadata=root_metadata, - data={ - "output": state["final_output"], - }, - ) - if _root_event_filter.include_event(event, state["type"]): + async for event in event_stream: yield event def transform( @@ -1936,7 +1800,8 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" - from langchain_core.tracers.log_stream import LogStreamCallbackHandler + # Mixin that is used by both astream log and astream events implementation + from langchain_core.tracers._streaming import _StreamingCallbackHandler # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(input, 2) @@ -1964,16 +1829,18 @@ class Runnable(Generic[Input, Output], ABC): context = copy_context() context.run(var_child_runnable_config.set, child_config) iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type] - if stream_log := next( + + if stream_handler := next( ( - h + cast(_StreamingCallbackHandler, h) for h in run_manager.handlers - if isinstance(h, LogStreamCallbackHandler) + # instance check OK here, it's a mixin + if isinstance(h, _StreamingCallbackHandler) # type: ignore[misc] ), None, ): # populates streamed_output in astream_log() output if needed - iterator = stream_log.tap_output_aiter(run_manager.run_id, iterator) + iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator) try: while True: if accepts_context(asyncio.create_task): diff --git a/libs/core/langchain_core/tracers/_streaming.py b/libs/core/langchain_core/tracers/_streaming.py new file mode 100644 index 00000000000..8c778daffa5 --- /dev/null +++ b/libs/core/langchain_core/tracers/_streaming.py @@ -0,0 +1,28 @@ +"""Internal tracers used for stream_log and astream events implementations.""" +import abc +from typing import AsyncIterator, TypeVar +from uuid import UUID + +T = TypeVar("T") + + +class _StreamingCallbackHandler(abc.ABC): + """For internal use. + + This is a common mixin that the callback handlers + for both astream events and astream log inherit from. + + The `tap_output_aiter` method is invoked in some contexts + to produce callbacks for intermediate results. 
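+ + Implementations are expected to act as transparent taps: `tap_output_aiter` re-yields every chunk it receives unchanged, emitting stream callbacks only as a side effect.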
+ """ + + @abc.abstractmethod + def tap_output_aiter( + self, run_id: UUID, output: AsyncIterator[T] + ) -> AsyncIterator[T]: + """Used for internal astream_log and astream events implementations.""" + + +__all__ = [ + "_StreamingCallbackHandler", +] diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py new file mode 100644 index 00000000000..38fee911364 --- /dev/null +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -0,0 +1,790 @@ +"""Internal tracer to power the event stream API.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, + cast, +) +from uuid import UUID + +from typing_extensions import NotRequired, TypedDict + +from langchain_core.callbacks.base import AsyncCallbackHandler +from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk +from langchain_core.outputs import ( + ChatGenerationChunk, + GenerationChunk, + LLMResult, +) +from langchain_core.runnables.schema import EventData, StreamEvent +from langchain_core.runnables.utils import ( + Input, + Output, + _RootEventFilter, +) +from langchain_core.tracers._streaming import _StreamingCallbackHandler +from langchain_core.tracers.log_stream import LogEntry +from langchain_core.tracers.memory_stream import _MemoryStream + +if TYPE_CHECKING: + from langchain_core.documents import Document + from langchain_core.runnables import Runnable, RunnableConfig + +logger = logging.getLogger(__name__) + + +class RunInfo(TypedDict): + """Information about a run.""" + + name: str + tags: List[str] + metadata: Dict[str, Any] + run_type: str + inputs: NotRequired[Any] + + +def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> str: + """Assign a name to a run.""" + if name is not None: + return name + if "name" in serialized: + return serialized["name"] + elif "id" in serialized: + return serialized["id"][-1] + return "Unnamed" + + +T = TypeVar("T") + + +class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHandler): + """An implementation of an async callback handler for astream events.""" + + def __init__( + self, + *args: Any, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> None: + """Initialize the tracer.""" + super().__init__(*args, **kwargs) + # Map of run ID to run info. + self.run_map: Dict[UUID, RunInfo] = {} + + # Filter which events will be sent over the queue. 
+ self.root_event_filter = _RootEventFilter( + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + loop = asyncio.get_event_loop() + memory_stream = _MemoryStream[StreamEvent](loop) + self.send_stream = memory_stream.get_send_stream() + self.receive_stream = memory_stream.get_receive_stream() + + async def _send(self, event: StreamEvent, event_type: str) -> None: + """Send an event to the stream.""" + if self.root_event_filter.include_event(event, event_type): + await self.send_stream.send(event) + + def __aiter__(self) -> AsyncIterator[Any]: + """Iterate over the receive stream.""" + return self.receive_stream.__aiter__() + + async def tap_output_aiter( + self, run_id: UUID, output: AsyncIterator[T] + ) -> AsyncIterator[T]: + """Tap the output aiter.""" + async for chunk in output: + run_info = self.run_map.get(run_id) + if run_info is None: + raise AssertionError(f"Run ID {run_id} not found in run map.") + await self._send( + { + "event": f"on_{run_info['run_type']}_stream", + "data": {"chunk": chunk}, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + yield chunk + + async def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + name_ = _assign_name(name, serialized) + run_type = "chat_model" + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type, + "inputs": {"messages": messages}, + } + + await self._send( + { + "event": "on_chat_model_start", + "data": { + "input": {"messages": messages}, + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type, + ) + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + name_ = _assign_name(name, serialized) + run_type = "llm" + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type, + "inputs": {"prompts": prompts}, + } + + await self._send( + { + "event": "on_llm_start", + "data": { + "input": { + "prompts": prompts, + } + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type, + ) + + async def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + """Run on new LLM token. 
Only available when streaming is enabled.""" + run_info = self.run_map.get(run_id) + + chunk_: Union[GenerationChunk, BaseMessageChunk] + + if run_info is None: + raise AssertionError(f"Run ID {run_id} not found in run map.") + if run_info["run_type"] == "chat_model": + event = "on_chat_model_stream" + + if chunk is None: + chunk_ = AIMessageChunk(content=token) + else: + chunk_ = cast(ChatGenerationChunk, chunk).message + + elif run_info["run_type"] == "llm": + event = "on_llm_stream" + if chunk is None: + chunk_ = GenerationChunk(text=token) + else: + chunk_ = cast(GenerationChunk, chunk) + else: + raise ValueError(f"Unexpected run type: {run_info['run_type']}") + + await self._send( + { + "event": event, + "data": { + "chunk": chunk_, + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + + async def on_llm_end( + self, response: LLMResult, *, run_id: UUID, **kwargs: Any + ) -> None: + """End a trace for an LLM run.""" + run_info = self.run_map.pop(run_id) + inputs_ = run_info["inputs"] + + generations: Union[List[List[GenerationChunk]], List[List[ChatGenerationChunk]]] + output: Union[dict, BaseMessage] = {} + + if run_info["run_type"] == "chat_model": + generations = cast(List[List[ChatGenerationChunk]], response.generations) + for gen in generations: + if output != {}: + break + for chunk in gen: + output = chunk.message + break + + event = "on_chat_model_end" + elif run_info["run_type"] == "llm": + generations = cast(List[List[GenerationChunk]], response.generations) + output = { + "generations": [ + [ + { + "text": chunk.text, + "generation_info": chunk.generation_info, + "type": chunk.type, + } + for chunk in gen + ] + for gen in generations + ], + "llm_output": response.llm_output, + } + event = "on_llm_end" + else: + raise ValueError(f"Unexpected run type: {run_info['run_type']}") + + await self._send( + { + "event": event, + "data": {"output": output, "input": inputs_}, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Start a trace for a chain run.""" + name_ = _assign_name(name, serialized) + run_type_ = run_type or "chain" + run_info: RunInfo = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type_, + } + + data: EventData = {} + + # Work-around Runnable core code not sending input in some + # cases. 
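+ # The core passes {"input": ""} when the real input is not known yet (e.g. while it is still being streamed), so skip recording it in that case.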
+ if inputs != {"input": ""}: + data["input"] = inputs + run_info["inputs"] = inputs + + self.run_map[run_id] = run_info + + await self._send( + { + "event": f"on_{run_type_}_start", + "data": data, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type_, + ) + + async def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """End a trace for a chain run.""" + run_info = self.run_map.pop(run_id) + run_type = run_info["run_type"] + + event = f"on_{run_type}_end" + + inputs = inputs or run_info.get("inputs") or {} + + data: EventData = { + "output": outputs, + "input": inputs, + } + + await self._send( + { + "event": event, + "data": data, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_type, + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + inputs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Start a trace for a tool run.""" + name_ = _assign_name(name, serialized) + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": "tool", + "inputs": inputs, + } + + await self._send( + { + "event": "on_tool_start", + "data": { + "input": inputs or {}, + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + "tool", + ) + + async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: + """End a trace for a tool run.""" + run_info = self.run_map.pop(run_id) + if "inputs" not in run_info: + raise AssertionError( + f"Run ID {run_id} is a tool call and is expected to have " + f"inputs associated with it." 
+ ) + inputs = run_info["inputs"] + + await self._send( + { + "event": "on_tool_end", + "data": { + "output": output, + "input": inputs, + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + "tool", + ) + + async def on_retriever_start( + self, + serialized: Dict[str, Any], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Run when Retriever starts running.""" + name_ = _assign_name(name, serialized) + run_type = "retriever" + self.run_map[run_id] = { + "tags": tags or [], + "metadata": metadata or {}, + "name": name_, + "run_type": run_type, + "inputs": {"query": query}, + } + + await self._send( + { + "event": "on_retriever_start", + "data": { + "input": { + "query": query, + } + }, + "name": name_, + "tags": tags or [], + "run_id": str(run_id), + "metadata": metadata or {}, + }, + run_type, + ) + + async def on_retriever_end( + self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any + ) -> None: + """Run when Retriever ends running.""" + run_info = self.run_map.pop(run_id) + + await self._send( + { + "event": "on_retriever_end", + "data": { + "output": documents, + "input": run_info["inputs"], + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + }, + run_info["run_type"], + ) + + def __deepcopy__(self, memo: dict) -> _AstreamEventsCallbackHandler: + """Deepcopy the tracer.""" + return self + + def __copy__(self) -> _AstreamEventsCallbackHandler: + """Copy the tracer.""" + return self + + +async def _astream_events_implementation_v1( + runnable: Runnable[Input, Output], + input: Any, + config: Optional[RunnableConfig] = None, + *, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> AsyncIterator[StreamEvent]: + from langchain_core.runnables import ensure_config + from langchain_core.runnables.utils import _RootEventFilter + from langchain_core.tracers.log_stream import ( + LogStreamCallbackHandler, + RunLog, + _astream_log_implementation, + ) + + stream = LogStreamCallbackHandler( + auto_close=False, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + _schema_format="streaming_events", + ) + + run_log = RunLog(state=None) # type: ignore[arg-type] + encountered_start_event = False + + _root_event_filter = _RootEventFilter( + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + config = ensure_config(config) + root_tags = config.get("tags", []) + root_metadata = config.get("metadata", {}) + root_name = config.get("run_name", runnable.get_name()) + + # Ignoring mypy complaint about too many different union combinations + # This arises because many of the argument types are unions + async for log in _astream_log_implementation( # type: ignore[misc] + runnable, + input, + config=config, + stream=stream, + diff=True, + with_streamed_output_list=True, 
+ **kwargs, + ): + run_log = run_log + log + + if not encountered_start_event: + # Yield the start event for the root runnable. + encountered_start_event = True + state = run_log.state.copy() + + event = StreamEvent( + event=f"on_{state['type']}_start", + run_id=state["id"], + name=root_name, + tags=root_tags, + metadata=root_metadata, + data={ + "input": input, + }, + ) + + if _root_event_filter.include_event(event, state["type"]): + yield event + + paths = { + op["path"].split("/")[2] + for op in log.ops + if op["path"].startswith("/logs/") + } + # Elements in a set should be iterated in the same order + # as they were inserted in modern python versions. + for path in paths: + data: EventData = {} + log_entry: LogEntry = run_log.state["logs"][path] + if log_entry["end_time"] is None: + if log_entry["streamed_output"]: + event_type = "stream" + else: + event_type = "start" + else: + event_type = "end" + + if event_type == "start": + # Include the inputs with the start event if they are available. + # Usually they will NOT be available for components that operate + # on streams, since those components stream the input and + # don't know its final value until the end of the stream. + inputs = log_entry["inputs"] + if inputs is not None: + data["input"] = inputs + pass + + if event_type == "end": + inputs = log_entry["inputs"] + if inputs is not None: + data["input"] = inputs + + # None is a VALID output for an end event + data["output"] = log_entry["final_output"] + + if event_type == "stream": + num_chunks = len(log_entry["streamed_output"]) + if num_chunks != 1: + raise AssertionError( + f"Expected exactly one chunk of streamed output, " + f"got {num_chunks} instead. This is impossible. " + f"Encountered in: {log_entry['name']}" + ) + + data = {"chunk": log_entry["streamed_output"][0]} + # Clean up the stream, we don't need it anymore. + # And this avoids duplicates as well! + log_entry["streamed_output"] = [] + + yield StreamEvent( + event=f"on_{log_entry['type']}_{event_type}", + name=log_entry["name"], + run_id=log_entry["id"], + tags=log_entry["tags"], + metadata=log_entry["metadata"], + data=data, + ) + + # Finally, we take care of the streaming output from the root chain + # if there is any. + state = run_log.state + if state["streamed_output"]: + num_chunks = len(state["streamed_output"]) + if num_chunks != 1: + raise AssertionError( + f"Expected exactly one chunk of streamed output, " + f"got {num_chunks} instead. This is impossible. " + f"Encountered in: {state['name']}" + ) + + data = {"chunk": state["streamed_output"][0]} + # Clean up the stream, we don't need it anymore. + state["streamed_output"] = [] + + event = StreamEvent( + event=f"on_{state['type']}_stream", + run_id=state["id"], + tags=root_tags, + metadata=root_metadata, + name=root_name, + data=data, + ) + if _root_event_filter.include_event(event, state["type"]): + yield event + + state = run_log.state + + # Finally yield the end event for the root runnable. 
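+ # Note that None is a valid `final_output` and is still reported as the output of the end event.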
+ event = StreamEvent( + event=f"on_{state['type']}_end", + name=root_name, + run_id=state["id"], + tags=root_tags, + metadata=root_metadata, + data={ + "output": state["final_output"], + }, + ) + if _root_event_filter.include_event(event, state["type"]): + yield event + + +async def _astream_events_implementation_v2( + runnable: Runnable[Input, Output], + input: Any, + config: Optional[RunnableConfig] = None, + *, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> AsyncIterator[StreamEvent]: + """Implementation of the astream events API for V2 runnables.""" + from langchain_core.callbacks.base import BaseCallbackManager + from langchain_core.runnables import ensure_config + + event_streamer = _AstreamEventsCallbackHandler( + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + # Assign the stream handler to the config + config = ensure_config(config) + callbacks = config.get("callbacks") + if callbacks is None: + config["callbacks"] = [event_streamer] + elif isinstance(callbacks, list): + config["callbacks"] = callbacks + [event_streamer] + elif isinstance(callbacks, BaseCallbackManager): + callbacks = callbacks.copy() + callbacks.add_handler(event_streamer, inherit=True) + config["callbacks"] = callbacks + else: + raise ValueError( + f"Unexpected type for callbacks: {callbacks}. " + "Expected None, a list, or a BaseCallbackManager." + ) + + # Consume the runnable in streaming mode; the events themselves are + # emitted by the callback handler onto the memory stream. + async def consume_astream() -> None: + try: + async for _ in runnable.astream(input, config, **kwargs): + # All content is surfaced through the callback events, + # so the chunks themselves are discarded here. + pass + finally: + await event_streamer.send_stream.aclose() + + # Start the runnable in a task, so we can start consuming output + task = asyncio.create_task(consume_astream()) + + first_event_sent = False + first_event_run_id = None + + try: + async for event in event_streamer: + if not first_event_sent: + first_event_sent = True + # This is a work-around for an issue where the inputs into the + # chain are not available until the entire input is consumed. + # As a temporary solution, we'll modify the input to be the input + # that was passed into the chain. + event["data"]["input"] = input + first_event_run_id = event["run_id"] + yield event + continue + + if event["run_id"] == first_event_run_id and event["event"].endswith( + "_end" + ): + # If it's the end event corresponding to the root runnable, + # we don't include the input in the event since it's guaranteed + # to be included in the first event. + if "input" in event["data"]: + del event["data"]["input"] + + yield event + finally: + # Wait for the runnable to finish, if not cancelled (e.g. 
by break) + try: + await task + except asyncio.CancelledError: + pass diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 4c02f933e42..47c789b6783 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -26,6 +26,7 @@ from langchain_core.load.load import load from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import Input, Output +from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.schemas import Run @@ -157,7 +158,7 @@ class RunLog(RunLogPatch): T = TypeVar("T") -class LogStreamCallbackHandler(BaseTracer): +class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): """Tracer that streams run logs to a stream.""" def __init__( diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index f233ba2d61d..011d482f331 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1198,7 +1198,7 @@ files = [ [[package]] name = "langchain-text-splitters" -version = "0.0.1" +version = "0.0.2" description = "LangChain text splitting utilities" optional = false python-versions = ">=3.8.1,<4.0" @@ -2142,7 +2142,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py similarity index 100% rename from libs/core/tests/unit_tests/runnables/test_runnable_events.py rename to libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py new file mode 100644 index 00000000000..47c69fe5f90 --- /dev/null +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -0,0 +1,1694 @@ +"""Module that contains tests for runnable.astream_events API.""" +import sys +from itertools import cycle +from typing import Any, AsyncIterator, Dict, List, Sequence, cast + +import pytest + +from 
langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.documents import Document +from langchain_core.language_models import FakeStreamingListLLM, GenericFakeChatModel +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.prompt_values import ChatPromptValue +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.retrievers import BaseRetriever +from langchain_core.runnables import ( + ConfigurableField, + Runnable, + RunnableConfig, + RunnableLambda, +) +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_core.runnables.schema import StreamEvent +from langchain_core.tools import tool +from tests.unit_tests.stubs import AnyStr + + +def _with_nulled_run_id(events: Sequence[StreamEvent]) -> List[StreamEvent]: + """Removes the run ids from events.""" + return cast(List[StreamEvent], [{**event, "run_id": ""} for event in events]) + + +async def _as_async_iterator(iterable: List) -> AsyncIterator: + """Converts an iterable into an async iterator.""" + for item in iterable: + yield item + + +async def _collect_events(events: AsyncIterator[StreamEvent]) -> List[StreamEvent]: + """Collect the events and remove the run ids.""" + materialized_events = [event async for event in events] + events_ = _with_nulled_run_id(materialized_events) + for event in events_: + event["tags"] = sorted(event["tags"]) + return events_ + + +async def test_event_stream_with_simple_function_tool() -> None: + """Test the event stream with a function and tool""" + + def foo(x: int) -> dict: + """Foo""" + return {"x": 5} + + @tool + def get_docs(x: int) -> List[Document]: + """Hello Doc""" + return [Document(page_content="hello")] + + chain = RunnableLambda(foo) | get_docs + events = await _collect_events(chain.astream_events({}, version="v2")) + assert events == [ + { + "event": "on_chain_start", + "run_id": "", + "name": "RunnableSequence", + "tags": [], + "metadata": {}, + "data": {"input": {}}, + }, + { + "event": "on_chain_start", + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + "metadata": {}, + "data": {}, + }, + { + "event": "on_chain_stream", + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + "metadata": {}, + "data": {"chunk": {"x": 5}}, + }, + { + "event": "on_chain_end", + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + "metadata": {}, + "data": {"input": {}, "output": {"x": 5}}, + }, + { + "event": "on_tool_start", + "name": "get_docs", + "run_id": "", + "tags": ["seq:step:2"], + "metadata": {}, + "data": {"input": {"x": 5}}, + }, + { + "event": "on_tool_end", + "name": "get_docs", + "run_id": "", + "tags": ["seq:step:2"], + "metadata": {}, + "data": {"input": {"x": 5}, "output": [Document(page_content="hello")]}, + }, + { + "event": "on_chain_stream", + "run_id": "", + "tags": [], + "metadata": {}, + "name": "RunnableSequence", + "data": {"chunk": [Document(page_content="hello")]}, + }, + { + "event": "on_chain_end", + "name": "RunnableSequence", + "run_id": "", + "tags": [], + "metadata": {}, + "data": {"output": [Document(page_content="hello")]}, + }, + ] + + +async def test_event_stream_with_single_lambda() -> None: + """Test the event stream with a tool.""" + + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + chain = 
RunnableLambda(func=reverse) + + events = await _collect_events(chain.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_triple_lambda() -> None: + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + r = RunnableLambda(func=reverse) + + chain = ( + r.with_config({"run_name": "1"}) + | r.with_config({"run_name": "2"}) + | r.with_config({"run_name": "3"}) + ) + events = await _collect_events(chain.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "2", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"input": "hello", "output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "hello"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "2", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["seq:step:3"], + }, + { + "data": {"input": "olleh", "output": "hello"}, + "event": "on_chain_end", + "metadata": {}, + "name": "2", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["seq:step:3"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": "hello", "output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["seq:step:3"], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_triple_lambda_test_filtering() -> None: + """Test filtering based on tags / names""" + + def reverse(s: str) -> str: + """Reverse a string.""" + return s[::-1] + + r = RunnableLambda(func=reverse) + + chain = ( + r.with_config({"run_name": "1"}) + | r.with_config({"run_name": "2", "tags": ["my_tag"]}) + | r.with_config({"run_name": "3", "tags": ["my_tag"]}) + ) + events = await _collect_events( + chain.astream_events("hello", include_names=["1"], version="v2") + ) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"output": "olleh"}, + "event": 
"on_chain_end", + "metadata": {}, + "name": "1", + "run_id": "", + "tags": ["seq:step:1"], + }, + ] + + events = await _collect_events( + chain.astream_events( + "hello", include_tags=["my_tag"], exclude_names=["2"], version="v2" + ) + ) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["my_tag", "seq:step:3"], + }, + { + "data": {"chunk": "olleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["my_tag", "seq:step:3"], + }, + { + "data": {"output": "olleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "3", + "run_id": "", + "tags": ["my_tag", "seq:step:3"], + }, + ] + + +async def test_event_stream_with_lambdas_from_lambda() -> None: + as_lambdas = RunnableLambda(lambda x: {"answer": "goodbye"}).with_config( + {"run_name": "my_lambda"} + ) + events = await _collect_events( + as_lambdas.astream_events({"question": "hello"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"question": "hello"}}, + "event": "on_chain_start", + "metadata": {}, + "name": "my_lambda", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": {"answer": "goodbye"}}, + "event": "on_chain_stream", + "metadata": {}, + "name": "my_lambda", + "run_id": "", + "tags": [], + }, + { + "data": {"output": {"answer": "goodbye"}}, + "event": "on_chain_end", + "metadata": {}, + "name": "my_lambda", + "run_id": "", + "tags": [], + }, + ] + + +async def test_astream_events_from_model() -> None: + """Test the output of a model.""" + infinite_cycle = cycle([AIMessage(content="hello world!")]) + # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces + model = ( + GenericFakeChatModel(messages=infinite_cycle) + .with_config( + { + "metadata": {"a": "b"}, + "tags": ["my_model"], + "run_name": "my_model", + } + ) + .bind(stop="") + ) + events = await _collect_events(model.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "output": AIMessageChunk(content="hello world!", id=AnyStr()), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + ] + + +async def test_astream_with_model_in_chain() -> None: + """Scenarios with model when it is not the only runnable in the chain.""" + infinite_cycle = cycle([AIMessage(content="hello world!")]) + # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces + model = ( + GenericFakeChatModel(messages=infinite_cycle) + .with_config( + { + "metadata": {"a": "b"}, + "tags": ["my_model"], + "run_name": "my_model", + } + ) + .bind(stop="") + ) + + @RunnableLambda + def i_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info 
>= (3, 11): + return model.invoke(input) + else: + return model.invoke(input, config) + + events = await _collect_events(i_dont_stream.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": AIMessage(content="hello world!", id=AnyStr()), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_stream", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_end", + "metadata": {}, + "name": "i_dont_stream", + "run_id": "", + "tags": [], + }, + ] + + @RunnableLambda + async def ai_dont_stream(input: Any, config: RunnableConfig) -> Any: + if sys.version_info >= (3, 11): + return await model.ainvoke(input) + else: + return await model.ainvoke(input, config) + + events = await _collect_events(ai_dont_stream.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chain_start", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"messages": [[HumanMessage(content="hello")]]}}, + "event": "on_chat_model_start", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id=AnyStr())}, + "event": "on_chat_model_stream", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": { + "input": {"messages": [[HumanMessage(content="hello")]]}, + "output": AIMessage(content="hello world!", id=AnyStr()), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b"}, + "name": "my_model", + "run_id": "", + "tags": ["my_model"], + }, + { + "data": {"chunk": AIMessage(content="hello world!", id=AnyStr())}, + "event": "on_chain_stream", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + { + "data": {"output": AIMessage(content="hello world!", id=AnyStr())}, + 
"event": "on_chain_end", + "metadata": {}, + "name": "ai_dont_stream", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_simple_chain() -> None: + """Test as event stream.""" + template = ChatPromptTemplate.from_messages( + [("system", "You are Cat Agent 007"), ("human", "{question}")] + ).with_config({"run_name": "my_template", "tags": ["my_template"]}) + + infinite_cycle = cycle( + [ + AIMessage(content="hello world!", id="ai1"), + AIMessage(content="goodbye world!", id="ai2"), + ] + ) + # When streaming GenericFakeChatModel breaks AIMessage into chunks based on spaces + model = ( + GenericFakeChatModel(messages=infinite_cycle) + .with_config( + { + "metadata": {"a": "b"}, + "tags": ["my_model"], + "run_name": "my_model", + } + ) + .bind(stop="") + ) + + chain = (template | model).with_config( + { + "metadata": {"foo": "bar"}, + "tags": ["my_chain"], + "run_name": "my_chain", + } + ) + + events = await _collect_events( + chain.astream_events({"question": "hello"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"question": "hello"}}, + "event": "on_chain_start", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": {"input": {"question": "hello"}}, + "event": "on_prompt_start", + "metadata": {"foo": "bar"}, + "name": "my_template", + "run_id": "", + "tags": ["my_chain", "my_template", "seq:step:1"], + }, + { + "data": { + "input": {"question": "hello"}, + "output": ChatPromptValue( + messages=[ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ), + }, + "event": "on_prompt_end", + "metadata": {"foo": "bar"}, + "name": "my_template", + "run_id": "", + "tags": ["my_chain", "my_template", "seq:step:1"], + }, + { + "data": { + "input": { + "messages": [ + [ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ] + } + }, + "event": "on_chat_model_start", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id="ai1")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content="hello", id="ai1")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id="ai1")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id="ai1")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id="ai1")}, + "event": "on_chat_model_stream", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"chunk": AIMessageChunk(content="world!", id="ai1")}, + "event": "on_chain_stream", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + { + "data": { + "input": { + "messages": [ + [ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ] + }, + "output": 
AIMessageChunk(content="hello world!", id="ai1"), + }, + "event": "on_chat_model_end", + "metadata": {"a": "b", "foo": "bar"}, + "name": "my_model", + "run_id": "", + "tags": ["my_chain", "my_model", "seq:step:2"], + }, + { + "data": {"output": AIMessageChunk(content="hello world!", id="ai1")}, + "event": "on_chain_end", + "metadata": {"foo": "bar"}, + "name": "my_chain", + "run_id": "", + "tags": ["my_chain"], + }, + ] + + +async def test_event_streaming_with_tools() -> None: + """Test streaming events with different tool definitions.""" + + @tool + def parameterless() -> str: + """A tool that does nothing.""" + return "hello" + + @tool + def with_callbacks(callbacks: Callbacks) -> str: + """A tool that does nothing.""" + return "world" + + @tool + def with_parameters(x: int, y: str) -> dict: + """A tool that does nothing.""" + return {"x": x, "y": y} + + @tool + def with_parameters_and_callbacks(x: int, y: str, callbacks: Callbacks) -> dict: + """A tool that does nothing.""" + return {"x": x, "y": y} + + # type ignores below because the tools don't appear to be runnables to type checkers + # we can remove as soon as that's fixed + events = await _collect_events(parameterless.astream_events({}, version="v2")) # type: ignore + assert events == [ + { + "data": {"input": {}}, + "event": "on_tool_start", + "metadata": {}, + "name": "parameterless", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "hello"}, + "event": "on_tool_end", + "metadata": {}, + "name": "parameterless", + "run_id": "", + "tags": [], + }, + ] + events = await _collect_events(with_callbacks.astream_events({}, version="v2")) # type: ignore + assert events == [ + { + "data": {"input": {}}, + "event": "on_tool_start", + "metadata": {}, + "name": "with_callbacks", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "world"}, + "event": "on_tool_end", + "metadata": {}, + "name": "with_callbacks", + "run_id": "", + "tags": [], + }, + ] + events = await _collect_events( + with_parameters.astream_events({"x": 1, "y": "2"}, version="v2") # type: ignore + ) + assert events == [ + { + "data": {"input": {"x": 1, "y": "2"}}, + "event": "on_tool_start", + "metadata": {}, + "name": "with_parameters", + "run_id": "", + "tags": [], + }, + { + "data": {"output": {"x": 1, "y": "2"}}, + "event": "on_tool_end", + "metadata": {}, + "name": "with_parameters", + "run_id": "", + "tags": [], + }, + ] + + events = await _collect_events( + with_parameters_and_callbacks.astream_events({"x": 1, "y": "2"}, version="v2") # type: ignore + ) + assert events == [ + { + "data": {"input": {"x": 1, "y": "2"}}, + "event": "on_tool_start", + "metadata": {}, + "name": "with_parameters_and_callbacks", + "run_id": "", + "tags": [], + }, + { + "data": {"output": {"x": 1, "y": "2"}}, + "event": "on_tool_end", + "metadata": {}, + "name": "with_parameters_and_callbacks", + "run_id": "", + "tags": [], + }, + ] + + +class HardCodedRetriever(BaseRetriever): + documents: List[Document] + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + return self.documents + + +async def test_event_stream_with_retriever() -> None: + """Test the event stream with a retriever.""" + retriever = HardCodedRetriever( + documents=[ + Document( + page_content="hello world!", + metadata={"foo": "bar"}, + ), + Document( + page_content="goodbye world!", + metadata={"food": "spare"}, + ), + ] + ) + events = await _collect_events( + retriever.astream_events({"query": "hello"}, version="v2") + ) + assert 
+
+
+class HardCodedRetriever(BaseRetriever):
+    documents: List[Document]
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        return self.documents
+
+
+async def test_event_stream_with_retriever() -> None:
+    """Test the event stream with a retriever."""
+    retriever = HardCodedRetriever(
+        documents=[
+            Document(
+                page_content="hello world!",
+                metadata={"foo": "bar"},
+            ),
+            Document(
+                page_content="goodbye world!",
+                metadata={"food": "spare"},
+            ),
+        ]
+    )
+    events = await _collect_events(
+        retriever.astream_events({"query": "hello"}, version="v2")
+    )
+    assert events == [
+        {
+            "data": {
+                "input": {"query": "hello"},
+            },
+            "event": "on_retriever_start",
+            "metadata": {},
+            "name": "HardCodedRetriever",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {
+                "output": [
+                    Document(page_content="hello world!", metadata={"foo": "bar"}),
+                    Document(page_content="goodbye world!", metadata={"food": "spare"}),
+                ]
+            },
+            "event": "on_retriever_end",
+            "metadata": {},
+            "name": "HardCodedRetriever",
+            "run_id": "",
+            "tags": [],
+        },
+    ]
+
+
+async def test_event_stream_with_retriever_and_formatter() -> None:
+    """Test the event stream with a retriever piped into a formatter."""
+    retriever = HardCodedRetriever(
+        documents=[
+            Document(
+                page_content="hello world!",
+                metadata={"foo": "bar"},
+            ),
+            Document(
+                page_content="goodbye world!",
+                metadata={"food": "spare"},
+            ),
+        ]
+    )
+
+    def format_docs(docs: List[Document]) -> str:
+        """Format the docs."""
+        return ", ".join([doc.page_content for doc in docs])
+
+    chain = retriever | format_docs
+    events = await _collect_events(chain.astream_events("hello", version="v2"))
+    assert events == [
+        {
+            "data": {"input": "hello"},
+            "event": "on_chain_start",
+            "metadata": {},
+            "name": "RunnableSequence",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {"input": {"query": "hello"}},
+            "event": "on_retriever_start",
+            "metadata": {},
+            "name": "HardCodedRetriever",
+            "run_id": "",
+            "tags": ["seq:step:1"],
+        },
+        {
+            "data": {
+                "input": {"query": "hello"},
+                "output": [
+                    Document(page_content="hello world!", metadata={"foo": "bar"}),
+                    Document(page_content="goodbye world!", metadata={"food": "spare"}),
+                ],
+            },
+            "event": "on_retriever_end",
+            "metadata": {},
+            "name": "HardCodedRetriever",
+            "run_id": "",
+            "tags": ["seq:step:1"],
+        },
+        {
+            "data": {},
+            "event": "on_chain_start",
+            "metadata": {},
+            "name": "format_docs",
+            "run_id": "",
+            "tags": ["seq:step:2"],
+        },
+        {
+            "data": {"chunk": "hello world!, goodbye world!"},
+            "event": "on_chain_stream",
+            "metadata": {},
+            "name": "format_docs",
+            "run_id": "",
+            "tags": ["seq:step:2"],
+        },
+        {
+            "data": {"chunk": "hello world!, goodbye world!"},
+            "event": "on_chain_stream",
+            "metadata": {},
+            "name": "RunnableSequence",
+            "run_id": "",
+            "tags": [],
+        },
+        {
+            "data": {
+                "input": [
+                    Document(page_content="hello world!", metadata={"foo": "bar"}),
+                    Document(page_content="goodbye world!", metadata={"food": "spare"}),
+                ],
+                "output": "hello world!, goodbye world!",
+            },
+            "event": "on_chain_end",
+            "metadata": {},
+            "name": "format_docs",
+            "run_id": "",
+            "tags": ["seq:step:2"],
+        },
+        {
+            "data": {"output": "hello world!, goodbye world!"},
+            "event": "on_chain_end",
+            "metadata": {},
+            "name": "RunnableSequence",
+            "run_id": "",
+            "tags": [],
+        },
+    ]
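+
+# Under the v2 schema the retriever's on_retriever_end "output" is the bare
+# list of Documents (v1 wrapped it as {"documents": [...]}), so a consumer can
+# read it directly (sketch):
+#
+#     docs = event["data"]["output"]  # List[Document]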
"name": "concat", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"input": {"a": "hello", "b": "world"}, "output": "helloworld"}, + "event": "on_tool_end", + "metadata": {}, + "name": "concat", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "dlrowolleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "dlrowolleh"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": "helloworld", "output": "dlrowolleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "reverse", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"output": "dlrowolleh"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +@pytest.mark.xfail(reason="Fix order of callback invocations in RunnableSequence") +async def test_chain_ordering() -> None: + """Test the event stream with a tool.""" + + def foo(a: str) -> str: + return a + + def bar(a: str) -> str: + return a + + chain = RunnableLambda(foo) | RunnableLambda(bar) + iterable = chain.astream_events("q", version="v2") + + events = [] + + for _ in range(10): + try: + next_chunk = await iterable.__anext__() + events.append(next_chunk) + except Exception: + break + + events = _with_nulled_run_id(events) + for event in events: + event["tags"] = sorted(event["tags"]) + + assert events == [ + { + "data": {"input": "q"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "q"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"input": "q", "output": "q"}, + "event": "on_chain_end", + "metadata": {}, + "name": "foo", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "bar", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "q"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "bar", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "q"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": "q", "output": "q"}, + "event": "on_chain_end", + "metadata": {}, + "name": "bar", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"output": "q"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_event_stream_with_retry() -> None: + """Test the event stream with a tool.""" + + def success(inputs: str) -> str: + return "success" + + def fail(inputs: str) -> None: + """Simple func.""" + raise Exception("fail") + + chain = RunnableLambda(success) | RunnableLambda(fail).with_retry( + stop_after_attempt=1, + ) + iterable = chain.astream_events("q", version="v2") + + events = [] + + for _ in range(10): + try: + next_chunk = await iterable.__anext__() + events.append(next_chunk) + except Exception: + break + + events = _with_nulled_run_id(events) + for event in events: + event["tags"] = 
sorted(event["tags"]) + + assert events == [ + { + "data": {"input": "q"}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "success", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {"chunk": "success"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "success", + "run_id": "", + "tags": ["seq:step:1"], + }, + { + "data": {}, + "event": "on_chain_start", + "metadata": {}, + "name": "fail", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"input": "q", "output": "success"}, + "event": "on_chain_end", + "metadata": {}, + "name": "success", + "run_id": "", + "tags": ["seq:step:1"], + }, + ] + + +async def test_with_llm() -> None: + """Test with regular llm.""" + prompt = ChatPromptTemplate.from_messages( + [("system", "You are Cat Agent 007"), ("human", "{question}")] + ).with_config({"run_name": "my_template", "tags": ["my_template"]}) + llm = FakeStreamingListLLM(responses=["abc"]) + + chain = prompt | llm + events = await _collect_events( + chain.astream_events({"question": "hello"}, version="v2") + ) + assert events == [ + { + "data": {"input": {"question": "hello"}}, + "event": "on_chain_start", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"input": {"question": "hello"}}, + "event": "on_prompt_start", + "metadata": {}, + "name": "my_template", + "run_id": "", + "tags": ["my_template", "seq:step:1"], + }, + { + "data": { + "input": {"question": "hello"}, + "output": ChatPromptValue( + messages=[ + SystemMessage(content="You are Cat Agent 007"), + HumanMessage(content="hello"), + ] + ), + }, + "event": "on_prompt_end", + "metadata": {}, + "name": "my_template", + "run_id": "", + "tags": ["my_template", "seq:step:1"], + }, + { + "data": { + "input": {"prompts": ["System: You are Cat Agent 007\n" "Human: hello"]} + }, + "event": "on_llm_start", + "metadata": {}, + "name": "FakeStreamingListLLM", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": { + "input": { + "prompts": ["System: You are Cat Agent 007\n" "Human: hello"] + }, + "output": { + "generations": [ + [{"generation_info": None, "text": "abc", "type": "Generation"}] + ], + "llm_output": None, + }, + }, + "event": "on_llm_end", + "metadata": {}, + "name": "FakeStreamingListLLM", + "run_id": "", + "tags": ["seq:step:2"], + }, + { + "data": {"chunk": "a"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "b"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": "c"}, + "event": "on_chain_stream", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + { + "data": {"output": "abc"}, + "event": "on_chain_end", + "metadata": {}, + "name": "RunnableSequence", + "run_id": "", + "tags": [], + }, + ] + + +async def test_runnable_each() -> None: + """Test runnable each astream_events.""" + + async def add_one(x: int) -> int: + return x + 1 + + add_one_map = RunnableLambda(add_one).map() # type: ignore + assert await add_one_map.ainvoke([1, 2, 3]) == [2, 3, 4] + + with pytest.raises(NotImplementedError): + async for _ in add_one_map.astream_events([1, 2, 3], version="v2"): + pass + + +async def test_events_astream_config() -> None: + """Test that astream events support accepting config""" + infinite_cycle = 
cycle([AIMessage(content="hello world!", id="ai1")]) + good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")]) + model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields( + messages=ConfigurableField( + id="messages", + name="Messages", + description="Messages return by the LLM", + ) + ) + + model_02 = model.with_config({"configurable": {"messages": good_world_on_repeat}}) + assert model_02.invoke("hello") == AIMessage(content="Goodbye world", id="ai2") + + events = await _collect_events(model_02.astream_events("hello", version="v2")) + assert events == [ + { + "data": {"input": "hello"}, + "event": "on_chat_model_start", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content="Goodbye", id="ai2")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content=" ", id="ai2")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": {"chunk": AIMessageChunk(content="world", id="ai2")}, + "event": "on_chat_model_stream", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + { + "data": { + "output": AIMessageChunk(content="Goodbye world", id="ai2"), + }, + "event": "on_chat_model_end", + "metadata": {}, + "name": "GenericFakeChatModel", + "run_id": "", + "tags": [], + }, + ] + + +async def test_runnable_with_message_history() -> None: + class InMemoryHistory(BaseChatMessageHistory, BaseModel): + """In memory implementation of chat message history.""" + + # Attention: for the tests use an Any type to work-around a pydantic issue + # where it re-instantiates a list, so mutating the list doesn't end up mutating + # the content in the store! + + # Using Any type here rather than List[BaseMessage] due to pydantic issue! + messages: Any + + def add_message(self, message: BaseMessage) -> None: + """Add a self-created message to the store.""" + self.messages.append(message) + + def clear(self) -> None: + self.messages = [] + + # Here we use a global variable to store the chat message history. + # This will make it easier to inspect it to see the underlying results. 
+
+
+async def test_runnable_with_message_history() -> None:
+    class InMemoryHistory(BaseChatMessageHistory, BaseModel):
+        """In memory implementation of chat message history."""
+
+        # Attention: for these tests we use an Any type rather than
+        # List[BaseMessage] to work around a pydantic issue where it
+        # re-instantiates the list, so mutating the list doesn't end up
+        # mutating the content in the store!
+        messages: Any
+
+        def add_message(self, message: BaseMessage) -> None:
+            """Add a self-created message to the store."""
+            self.messages.append(message)
+
+        def clear(self) -> None:
+            self.messages = []
+
+    # Here we use a global variable to store the chat message history.
+    # This makes it easier to inspect the underlying results.
+    store: Dict = {}
+
+    def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
+        """Get a chat message history."""
+        if session_id not in store:
+            store[session_id] = []
+        return InMemoryHistory(messages=store[session_id])
+
+    infinite_cycle = cycle(
+        [AIMessage(content="hello", id="ai3"), AIMessage(content="world", id="ai4")]
+    )
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", "You are a cat"),
+            MessagesPlaceholder(variable_name="history"),
+            ("human", "{question}"),
+        ]
+    )
+    model = GenericFakeChatModel(messages=infinite_cycle)
+
+    chain: Runnable = prompt | model
+    with_message_history = RunnableWithMessageHistory(
+        chain,
+        get_session_history=get_by_session_id,
+        input_messages_key="question",
+        history_messages_key="history",
+    )
+    await with_message_history.with_config(
+        {"configurable": {"session_id": "session-123"}}
+    ).ainvoke({"question": "hello"})
+
+    assert store == {
+        "session-123": [
+            HumanMessage(content="hello"),
+            AIMessage(content="hello", id="ai3"),
+        ]
+    }
+
+    with_message_history.with_config(
+        {"configurable": {"session_id": "session-123"}}
+    ).invoke({"question": "meow"})
+    assert store == {
+        "session-123": [
+            HumanMessage(content="hello"),
+            AIMessage(content="hello", id="ai3"),
+            HumanMessage(content="meow"),
+            AIMessage(content="world", id="ai4"),
+        ]
+    }
+
+
+EXPECTED_EVENTS = [
+    {
+        "data": {"input": 1},
+        "event": "on_chain_start",
+        "metadata": {},
+        "name": "add_one_proxy",
+        "run_id": "",
+        "tags": [],
+    },
+    {
+        "data": {},
+        "event": "on_chain_start",
+        "metadata": {},
+        "name": "add_one",
+        "run_id": "",
+        "tags": [],
+    },
+    {
+        "data": {"chunk": 2},
+        "event": "on_chain_stream",
+        "metadata": {},
+        "name": "add_one",
+        "run_id": "",
+        "tags": [],
+    },
+    {
+        "data": {"input": 1, "output": 2},
+        "event": "on_chain_end",
+        "metadata": {},
+        "name": "add_one",
+        "run_id": "",
+        "tags": [],
+    },
+    {
+        "data": {"chunk": 2},
+        "event": "on_chain_stream",
+        "metadata": {},
+        "name": "add_one_proxy",
+        "run_id": "",
+        "tags": [],
+    },
+    {
+        "data": {"output": 2},
+        "event": "on_chain_end",
+        "metadata": {},
+        "name": "add_one_proxy",
+        "run_id": "",
+        "tags": [],
+    },
+]
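+
+# EXPECTED_EVENTS is shared by the three nested-lambda tests below: each
+# should surface the same outer (add_one_proxy) and inner (add_one) runs.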
+ "Need to implement logic in _transform_stream_with_config that mimics the async " + "variant that uses tap_output_iter" +) +async def test_sync_in_async_stream_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + def add_one_(x: int) -> int: + return x + 1 + + add_one = RunnableLambda(add_one_) + + async def add_one_proxy_(x: int, config: RunnableConfig) -> int: + streaming = add_one.stream(x, config) + results = [result for result in streaming] + return results[0] + + add_one_proxy = RunnableLambda(add_one_proxy_) # type: ignore + + events = await _collect_events(add_one_proxy.astream_events(1, version="v2")) + assert events == EXPECTED_EVENTS + + +async def test_async_in_async_stream_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + async def add_one(x: int) -> int: + return x + 1 + + add_one_ = RunnableLambda(add_one) # type: ignore + + async def add_one_proxy(x: int, config: RunnableConfig) -> int: + # Use sync streaming + streaming = add_one_.astream(x, config) + results = [result async for result in streaming] + return results[0] + + add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore + + events = await _collect_events(add_one_proxy_.astream_events(1, version="v2")) + assert events == EXPECTED_EVENTS + + +@pytest.mark.xfail( + reason="This test is failing due to missing functionality." + "Need to implement logic in _transform_stream_with_config that mimics the async " + "variant that uses tap_output_iter" +) +async def test_sync_in_sync_lambdas() -> None: + """Test invoking nested runnable lambda.""" + + def add_one(x: int) -> int: + return x + 1 + + add_one_ = RunnableLambda(add_one) + + def add_one_proxy(x: int, config: RunnableConfig) -> int: + # Use sync streaming + streaming = add_one_.stream(x, config) + results = [result for result in streaming] + return results[0] + + add_one_proxy_ = RunnableLambda(add_one_proxy) + + events = await _collect_events(add_one_proxy_.astream_events(1, version="v2")) + assert events == EXPECTED_EVENTS