mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
core[minor]: Add dispatching for custom events (#24080)
This PR allows dispatching adhoc events for a given run. # Context This PR allows users to send arbitrary data to the callback system and to the astream events API from within a given runnable. This can be extremely useful to surface custom information to end users about progress etc. Integration with langsmith tracer will be done separately since the data cannot be currently visualized. It'll be accommodated using the events attribute of the Run # Examples with astream events ```python from langchain_core.callbacks import adispatch_custom_event from langchain_core.tools import tool @tool async def foo(x: int) -> int: """Foo""" await adispatch_custom_event("event1", {"x": x}) await adispatch_custom_event("event2", {"x": x}) return x + 1 async for event in foo.astream_events({'x': 1}, version='v2'): print(event) ``` ```python {'event': 'on_tool_start', 'data': {'input': {'x': 1}}, 'name': 'foo', 'tags': [], 'run_id': 'fd6fb7a7-dd37-4191-962c-e43e245909f6', 'metadata': {}, 'parent_ids': []} {'event': 'on_custom_event', 'run_id': 'fd6fb7a7-dd37-4191-962c-e43e245909f6', 'name': 'event1', 'tags': [], 'metadata': {}, 'data': {'x': 1}, 'parent_ids': []} {'event': 'on_custom_event', 'run_id': 'fd6fb7a7-dd37-4191-962c-e43e245909f6', 'name': 'event2', 'tags': [], 'metadata': {}, 'data': {'x': 1}, 'parent_ids': []} {'event': 'on_tool_end', 'data': {'output': 2}, 'run_id': 'fd6fb7a7-dd37-4191-962c-e43e245909f6', 'name': 'foo', 'tags': [], 'metadata': {}, 'parent_ids': []} ``` ```python from langchain_core.callbacks import adispatch_custom_event from langchain_core.runnables import RunnableLambda @RunnableLambda async def foo(x: int) -> int: """Foo""" await adispatch_custom_event("event1", {"x": x}) await adispatch_custom_event("event2", {"x": x}) return x + 1 async for event in foo.astream_events(1, version='v2'): print(event) ``` ```python {'event': 'on_chain_start', 'data': {'input': 1}, 'name': 'foo', 'tags': [], 'run_id': 'ce2beef2-8608-49ea-8eba-537bdaafb8ec', 'metadata': {}, 'parent_ids': []} {'event': 'on_custom_event', 'run_id': 'ce2beef2-8608-49ea-8eba-537bdaafb8ec', 'name': 'event1', 'tags': [], 'metadata': {}, 'data': {'x': 1}, 'parent_ids': []} {'event': 'on_custom_event', 'run_id': 'ce2beef2-8608-49ea-8eba-537bdaafb8ec', 'name': 'event2', 'tags': [], 'metadata': {}, 'data': {'x': 1}, 'parent_ids': []} {'event': 'on_chain_stream', 'run_id': 'ce2beef2-8608-49ea-8eba-537bdaafb8ec', 'name': 'foo', 'tags': [], 'metadata': {}, 'data': {'chunk': 2}, 'parent_ids': []} {'event': 'on_chain_end', 'data': {'output': 2}, 'run_id': 'ce2beef2-8608-49ea-8eba-537bdaafb8ec', 'name': 'foo', 'tags': [], 'metadata': {}, 'parent_ids': []} ``` # Examples with handlers This is copy pasted from unit tests ```python class CustomCallbackManager(BaseCallbackHandler): def __init__(self) -> None: self.events: List[Any] = [] def on_custom_event( self, name: str, data: Any, *, run_id: UUID, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: assert kwargs == {} self.events.append( ( name, data, run_id, tags, metadata, ) ) callback = CustomCallbackManager() run_id = uuid.UUID(int=7) @RunnableLambda def foo(x: int, config: RunnableConfig) -> int: dispatch_custom_event("event1", {"x": x}) dispatch_custom_event("event2", {"x": x}, config=config) return x foo.invoke(1, {"callbacks": [callback], "run_id": run_id}) assert callback.events == [ ("event1", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}), ("event2", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}), ] ```
This commit is contained in:
parent
14a8bbc21a
commit
dc131ac42a
@ -38,11 +38,15 @@ from langchain_core.callbacks.manager import (
|
||||
CallbackManagerForToolRun,
|
||||
ParentRunManager,
|
||||
RunManager,
|
||||
adispatch_custom_event,
|
||||
dispatch_custom_event,
|
||||
)
|
||||
from langchain_core.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
__all__ = [
|
||||
"dispatch_custom_event",
|
||||
"adispatch_custom_event",
|
||||
"RetrieverManagerMixin",
|
||||
"LLMManagerMixin",
|
||||
"ChainManagerMixin",
|
||||
|
@ -370,6 +370,31 @@ class RunManagerMixin:
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
|
||||
def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Override to define a handler for a custom event.
|
||||
|
||||
Args:
|
||||
name: The name of the custom event.
|
||||
data: The data for the custom event. Format will match
|
||||
the format specified by the user.
|
||||
run_id: The ID of the run.
|
||||
tags: The tags associated with the custom event
|
||||
(includes inherited tags).
|
||||
metadata: The metadata associated with the custom event
|
||||
(includes inherited metadata).
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
|
||||
|
||||
class BaseCallbackHandler(
|
||||
LLMManagerMixin,
|
||||
@ -417,6 +442,11 @@ class BaseCallbackHandler(
|
||||
"""Whether to ignore chat model callbacks."""
|
||||
return False
|
||||
|
||||
@property
|
||||
def ignore_custom_event(self) -> bool:
|
||||
"""Ignore custom event."""
|
||||
return False
|
||||
|
||||
|
||||
class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
"""Async callback handler for LangChain."""
|
||||
@ -799,6 +829,31 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
kwargs (Any): Additional keyword arguments.
|
||||
"""
|
||||
|
||||
async def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Override to define a handler for a custom event.
|
||||
|
||||
Args:
|
||||
name: The name of the custom event.
|
||||
data: The data for the custom event. Format will match
|
||||
the format specified by the user.
|
||||
run_id: The ID of the run.
|
||||
tags: The tags associated with the custom event
|
||||
(includes inherited tags).
|
||||
metadata: The metadata associated with the custom event
|
||||
(includes inherited metadata).
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
|
||||
|
||||
T = TypeVar("T", bound="BaseCallbackManager")
|
||||
|
||||
|
@ -48,6 +48,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1494,6 +1495,46 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Dispatch an adhoc event to the handlers (async version).
|
||||
|
||||
This event should NOT be used in any internal LangChain code. The event
|
||||
is meant specifically for users of the library to dispatch custom
|
||||
events that are tailored to their application.
|
||||
|
||||
Args:
|
||||
name: The name of the adhoc event.
|
||||
data: The data for the adhoc event.
|
||||
run_id: The ID of the run. Defaults to None.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
if kwargs:
|
||||
raise ValueError(
|
||||
"The dispatcher API does not accept additional keyword arguments."
|
||||
"Please do not pass any additional keyword arguments, instead "
|
||||
"include them in the data field."
|
||||
)
|
||||
if run_id is None:
|
||||
run_id = uuid.uuid4()
|
||||
|
||||
handle_event(
|
||||
self.handlers,
|
||||
"on_custom_event",
|
||||
"ignore_custom_event",
|
||||
name,
|
||||
data,
|
||||
run_id=run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def configure(
|
||||
cls,
|
||||
@ -1833,6 +1874,46 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Dispatch an adhoc event to the handlers (async version).
|
||||
|
||||
This event should NOT be used in any internal LangChain code. The event
|
||||
is meant specifically for users of the library to dispatch custom
|
||||
events that are tailored to their application.
|
||||
|
||||
Args:
|
||||
name: The name of the adhoc event.
|
||||
data: The data for the adhoc event.
|
||||
run_id: The ID of the run. Defaults to None.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
if run_id is None:
|
||||
run_id = uuid.uuid4()
|
||||
|
||||
if kwargs:
|
||||
raise ValueError(
|
||||
"The dispatcher API does not accept additional keyword arguments."
|
||||
"Please do not pass any additional keyword arguments, instead "
|
||||
"include them in the data field."
|
||||
)
|
||||
await ahandle_event(
|
||||
self.handlers,
|
||||
"on_custom_event",
|
||||
"ignore_custom_event",
|
||||
name,
|
||||
data,
|
||||
run_id=run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
async def on_retriever_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
@ -2169,3 +2250,189 @@ def _configure(
|
||||
):
|
||||
callback_manager.add_handler(var_handler, inheritable)
|
||||
return callback_manager
|
||||
|
||||
|
||||
async def adispatch_custom_event(
|
||||
name: str, data: Any, *, config: Optional[RunnableConfig] = None
|
||||
) -> None:
|
||||
"""Dispatch an adhoc event to the handlers.
|
||||
|
||||
Args:
|
||||
name: The name of the adhoc event.
|
||||
data: The data for the adhoc event. Free form data. Ideally should be
|
||||
JSON serializable to avoid serialization issues downstream, but
|
||||
this is not enforced.
|
||||
config: Optional config object. Mirrors the async API but not strictly needed.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackHandler,
|
||||
adispatch_custom_event
|
||||
)
|
||||
from langchain_core.runnable import RunnableLambda
|
||||
|
||||
class CustomCallbackManager(AsyncCallbackHandler):
|
||||
async def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
print(f"Received custom event: {name} with data: {data}")
|
||||
|
||||
callback = CustomCallbackManager()
|
||||
|
||||
async def foo(inputs):
|
||||
await adispatch_custom_event("my_event", {"bar": "buzz})
|
||||
return inputs
|
||||
|
||||
foo_ = RunnableLambda(foo)
|
||||
await foo_.ainvoke({"a": "1"}, {"callbacks": [CustomCallbackManager()]})
|
||||
|
||||
Example: Use with astream events
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackHandler,
|
||||
adispatch_custom_event
|
||||
)
|
||||
from langchain_core.runnable import RunnableLambda
|
||||
|
||||
class CustomCallbackManager(AsyncCallbackHandler):
|
||||
async def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
print(f"Received custom event: {name} with data: {data}")
|
||||
|
||||
callback = CustomCallbackManager()
|
||||
|
||||
async def foo(inputs):
|
||||
await adispatch_custom_event("event_type_1", {"bar": "buzz})
|
||||
await adispatch_custom_event("event_type_2", 5)
|
||||
return inputs
|
||||
|
||||
foo_ = RunnableLambda(foo)
|
||||
|
||||
async for event in foo_.ainvoke_stream(
|
||||
{"a": "1"},
|
||||
version="v2",
|
||||
config={"callbacks": [CustomCallbackManager()]}
|
||||
):
|
||||
print(event)
|
||||
|
||||
.. warning: If using python <= 3.10 and async, you MUST
|
||||
specify the `config` parameter or the function will raise an error.
|
||||
This is due to a limitation in asyncio for python <= 3.10 that prevents
|
||||
LangChain from automatically propagating the config object on the user's
|
||||
behalf.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
from langchain_core.runnables.config import (
|
||||
ensure_config,
|
||||
get_async_callback_manager_for_config,
|
||||
)
|
||||
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# We want to get the callback manager for the parent run.
|
||||
# This is a work-around for now to be able to dispatch adhoc events from
|
||||
# within a tool or a lambda and have the metadata events associated
|
||||
# with the parent run rather than have a new run id generated for each.
|
||||
if callback_manager.parent_run_id is None:
|
||||
raise RuntimeError(
|
||||
"Unable to dispatch an adhoc event without a parent run id."
|
||||
"This function can only be called from within an existing run (e.g.,"
|
||||
"inside a tool or a RunnableLambda or a RunnableGenerator.)"
|
||||
"If you are doing that and still seeing this error, try explicitly"
|
||||
"passing the config parameter to this function."
|
||||
)
|
||||
|
||||
await callback_manager.on_custom_event(
|
||||
name,
|
||||
data,
|
||||
run_id=callback_manager.parent_run_id,
|
||||
)
|
||||
|
||||
|
||||
def dispatch_custom_event(
|
||||
name: str, data: Any, *, config: Optional[RunnableConfig] = None
|
||||
) -> None:
|
||||
"""Dispatch an adhoc event.
|
||||
|
||||
Args:
|
||||
name: The name of the adhoc event.
|
||||
data: The data for the adhoc event. Free form data. Ideally should be
|
||||
JSON serializable to avoid serialization issues downstream, but
|
||||
this is not enforced.
|
||||
config: Optional config object. Mirrors the async API but not strictly needed.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.callbacks import dispatch_custom_event
|
||||
from langchain_core.runnable import RunnableLambda
|
||||
|
||||
class CustomCallbackManager(BaseCallbackHandler):
|
||||
def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
print(f"Received custom event: {name} with data: {data}")
|
||||
|
||||
def foo(inputs):
|
||||
dispatch_custom_event("my_event", {"bar": "buzz})
|
||||
return inputs
|
||||
|
||||
foo_ = RunnableLambda(foo)
|
||||
foo_.invoke({"a": "1"}, {"callbacks": [CustomCallbackManager()]})
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
from langchain_core.runnables.config import (
|
||||
ensure_config,
|
||||
get_callback_manager_for_config,
|
||||
)
|
||||
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# We want to get the callback manager for the parent run.
|
||||
# This is a work-around for now to be able to dispatch adhoc events from
|
||||
# within a tool or a lambda and have the metadata events associated
|
||||
# with the parent run rather than have a new run id generated for each.
|
||||
if callback_manager.parent_run_id is None:
|
||||
raise RuntimeError(
|
||||
"Unable to dispatch an adhoc event without a parent run id."
|
||||
"This function can only be called from within an existing run (e.g.,"
|
||||
"inside a tool or a RunnableLambda or a RunnableGenerator.)"
|
||||
"If you are doing that and still seeing this error, try explicitly"
|
||||
"passing the config parameter to this function."
|
||||
)
|
||||
callback_manager.on_custom_event(
|
||||
name,
|
||||
data,
|
||||
run_id=callback_manager.parent_run_id,
|
||||
)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Sequence
|
||||
from typing import Any, Dict, List, Literal, Sequence, Union
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
@ -37,7 +37,7 @@ class EventData(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
|
||||
class StreamEvent(TypedDict):
|
||||
class BaseStreamEvent(TypedDict):
|
||||
"""Streaming event.
|
||||
|
||||
Schema of a streaming event which is produced from the astream_events method.
|
||||
@ -101,8 +101,6 @@ class StreamEvent(TypedDict):
|
||||
|
||||
Please see the documentation for `EventData` for more details.
|
||||
"""
|
||||
name: str
|
||||
"""The name of the runnable that generated the event."""
|
||||
run_id: str
|
||||
"""An randomly generated ID to keep track of the execution of the given runnable.
|
||||
|
||||
@ -128,11 +126,6 @@ class StreamEvent(TypedDict):
|
||||
|
||||
`.astream_events(..., {"metadata": {"foo": "bar"}})`.
|
||||
"""
|
||||
data: EventData
|
||||
"""Event data.
|
||||
|
||||
The contents of the event data depend on the event type.
|
||||
"""
|
||||
|
||||
parent_ids: Sequence[str]
|
||||
"""A list of the parent IDs associated with this event.
|
||||
@ -146,3 +139,34 @@ class StreamEvent(TypedDict):
|
||||
|
||||
Only supported as of v2 of the astream events API. v1 will return an empty list.
|
||||
"""
|
||||
|
||||
|
||||
class StandardStreamEvent(BaseStreamEvent):
|
||||
"""A standard stream event that follows LangChain convention for event data."""
|
||||
|
||||
data: EventData
|
||||
"""Event data.
|
||||
|
||||
The contents of the event data depend on the event type.
|
||||
"""
|
||||
name: str
|
||||
"""The name of the runnable that generated the event."""
|
||||
|
||||
|
||||
class CustomStreamEvent(BaseStreamEvent):
|
||||
"""A custom stream event created by the user.
|
||||
|
||||
.. versionadded:: 0.2.14
|
||||
"""
|
||||
|
||||
# Overwrite the event field to be more specific.
|
||||
event: Literal["on_custom_event"] # type: ignore[misc]
|
||||
|
||||
"""The event type."""
|
||||
name: str
|
||||
"""A user defined name for the event."""
|
||||
data: Any
|
||||
"""The data associated with the event. Free form and can be anything."""
|
||||
|
||||
|
||||
StreamEvent = Union[StandardStreamEvent, CustomStreamEvent]
|
||||
|
@ -28,7 +28,12 @@ from langchain_core.outputs import (
|
||||
GenerationChunk,
|
||||
LLMResult,
|
||||
)
|
||||
from langchain_core.runnables.schema import EventData, StreamEvent
|
||||
from langchain_core.runnables.schema import (
|
||||
CustomStreamEvent,
|
||||
EventData,
|
||||
StandardStreamEvent,
|
||||
StreamEvent,
|
||||
)
|
||||
from langchain_core.runnables.utils import (
|
||||
Input,
|
||||
Output,
|
||||
@ -161,7 +166,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
return
|
||||
if tap is sentinel:
|
||||
# if we are the first to tap, issue stream events
|
||||
event: StreamEvent = {
|
||||
event: StandardStreamEvent = {
|
||||
"event": f"on_{run_info['run_type']}_stream",
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
@ -203,7 +208,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
return
|
||||
if tap is sentinel:
|
||||
# if we are the first to tap, issue stream events
|
||||
event: StreamEvent = {
|
||||
event: StandardStreamEvent = {
|
||||
"event": f"on_{run_info['run_type']}_stream",
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
@ -341,6 +346,28 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
run_type,
|
||||
)
|
||||
|
||||
async def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Generate a custom astream event."""
|
||||
event = CustomStreamEvent(
|
||||
event="on_custom_event",
|
||||
run_id=str(run_id),
|
||||
name=name,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
data=data,
|
||||
parent_ids=self._get_parent_ids(run_id),
|
||||
)
|
||||
self._send(event, name)
|
||||
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
@ -678,7 +705,7 @@ async def _astream_events_implementation_v1(
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
) -> AsyncIterator[StandardStreamEvent]:
|
||||
from langchain_core.runnables import ensure_config
|
||||
from langchain_core.runnables.utils import _RootEventFilter
|
||||
from langchain_core.tracers.log_stream import (
|
||||
@ -733,7 +760,7 @@ async def _astream_events_implementation_v1(
|
||||
encountered_start_event = True
|
||||
state = run_log.state.copy()
|
||||
|
||||
event = StreamEvent(
|
||||
event = StandardStreamEvent(
|
||||
event=f"on_{state['type']}_start",
|
||||
run_id=state["id"],
|
||||
name=root_name,
|
||||
@ -798,7 +825,7 @@ async def _astream_events_implementation_v1(
|
||||
# And this avoids duplicates as well!
|
||||
log_entry["streamed_output"] = []
|
||||
|
||||
yield StreamEvent(
|
||||
yield StandardStreamEvent(
|
||||
event=f"on_{log_entry['type']}_{event_type}",
|
||||
name=log_entry["name"],
|
||||
run_id=log_entry["id"],
|
||||
@ -824,7 +851,7 @@ async def _astream_events_implementation_v1(
|
||||
# Clean up the stream, we don't need it anymore.
|
||||
state["streamed_output"] = []
|
||||
|
||||
event = StreamEvent(
|
||||
event = StandardStreamEvent(
|
||||
event=f"on_{state['type']}_stream",
|
||||
run_id=state["id"],
|
||||
tags=root_tags,
|
||||
@ -839,7 +866,7 @@ async def _astream_events_implementation_v1(
|
||||
state = run_log.state
|
||||
|
||||
# Finally yield the end event for the root runnable.
|
||||
event = StreamEvent(
|
||||
event = StandardStreamEvent(
|
||||
event=f"on_{state['type']}_end",
|
||||
name=root_name,
|
||||
run_id=state["id"],
|
||||
@ -866,7 +893,7 @@ async def _astream_events_implementation_v2(
|
||||
exclude_types: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
) -> AsyncIterator[StandardStreamEvent]:
|
||||
"""Implementation of the astream events API for V2 runnables."""
|
||||
from langchain_core.callbacks.base import BaseCallbackManager
|
||||
from langchain_core.runnables import ensure_config
|
||||
|
@ -0,0 +1,161 @@
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.callbacks.manager import (
|
||||
adispatch_custom_event,
|
||||
dispatch_custom_event,
|
||||
)
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
class AsyncCustomCallbackHandler(AsyncCallbackHandler):
|
||||
def __init__(self) -> None:
|
||||
self.events: List[Any] = []
|
||||
|
||||
async def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
assert kwargs == {}
|
||||
self.events.append(
|
||||
(
|
||||
name,
|
||||
data,
|
||||
run_id,
|
||||
tags,
|
||||
metadata,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_custom_event_root_dispatch() -> None:
|
||||
"""Test adhoc event in a nested chain."""
|
||||
# This just tests that nothing breaks on the path.
|
||||
# It shouldn't do anything at the moment, since the tracer isn't configured
|
||||
# to handle adhoc events.
|
||||
# Expected behavior is that the event cannot be dispatched
|
||||
with pytest.raises(RuntimeError):
|
||||
dispatch_custom_event("event1", {"x": 1})
|
||||
|
||||
|
||||
async def test_async_custom_event_root_dispatch() -> None:
|
||||
"""Test adhoc event in a nested chain."""
|
||||
# This just tests that nothing breaks on the path.
|
||||
# It shouldn't do anything at the moment, since the tracer isn't configured
|
||||
# to handle adhoc events.
|
||||
# Expected behavior is that the event cannot be dispatched
|
||||
with pytest.raises(RuntimeError):
|
||||
await adispatch_custom_event("event1", {"x": 1})
|
||||
|
||||
|
||||
IS_GTE_3_11 = sys.version_info >= (3, 11)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11")
|
||||
async def test_async_custom_event_implicit_config() -> None:
|
||||
"""Test dispatch without passing config explicitly."""
|
||||
callback = AsyncCustomCallbackHandler()
|
||||
|
||||
run_id = uuid.UUID(int=7)
|
||||
|
||||
# Typing not working well with RunnableLambda when used as
|
||||
# a decorator for async functions
|
||||
@RunnableLambda # type: ignore[arg-type]
|
||||
async def foo(x: int, config: RunnableConfig) -> int:
|
||||
await adispatch_custom_event("event1", {"x": x})
|
||||
await adispatch_custom_event("event2", {"x": x})
|
||||
return x
|
||||
|
||||
await foo.ainvoke(
|
||||
1, # type: ignore[arg-type]
|
||||
{"callbacks": [callback], "run_id": run_id},
|
||||
)
|
||||
|
||||
assert callback.events == [
|
||||
("event1", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}),
|
||||
("event2", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}),
|
||||
]
|
||||
|
||||
|
||||
async def test_async_callback_manager() -> None:
|
||||
"""Test async callback manager."""
|
||||
|
||||
callback = AsyncCustomCallbackHandler()
|
||||
|
||||
run_id = uuid.UUID(int=7)
|
||||
|
||||
# Typing not working well with RunnableLambda when used as
|
||||
# a decorator for async functions
|
||||
@RunnableLambda # type: ignore[arg-type]
|
||||
async def foo(x: int, config: RunnableConfig) -> int:
|
||||
await adispatch_custom_event("event1", {"x": x}, config=config)
|
||||
await adispatch_custom_event("event2", {"x": x}, config=config)
|
||||
return x
|
||||
|
||||
await foo.ainvoke(
|
||||
1, # type: ignore[arg-type]
|
||||
{"callbacks": [callback], "run_id": run_id},
|
||||
)
|
||||
|
||||
assert callback.events == [
|
||||
("event1", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}),
|
||||
("event2", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}),
|
||||
]
|
||||
|
||||
|
||||
def test_sync_callback_manager() -> None:
|
||||
"""Test async callback manager."""
|
||||
|
||||
class CustomCallbackManager(BaseCallbackHandler):
|
||||
def __init__(self) -> None:
|
||||
self.events: List[Any] = []
|
||||
|
||||
def on_custom_event(
|
||||
self,
|
||||
name: str,
|
||||
data: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
assert kwargs == {}
|
||||
self.events.append(
|
||||
(
|
||||
name,
|
||||
data,
|
||||
run_id,
|
||||
tags,
|
||||
metadata,
|
||||
)
|
||||
)
|
||||
|
||||
callback = CustomCallbackManager()
|
||||
|
||||
run_id = uuid.UUID(int=7)
|
||||
|
||||
@RunnableLambda
|
||||
def foo(x: int, config: RunnableConfig) -> int:
|
||||
dispatch_custom_event("event1", {"x": x})
|
||||
dispatch_custom_event("event2", {"x": x}, config=config)
|
||||
return x
|
||||
|
||||
foo.invoke(1, {"callbacks": [callback], "run_id": run_id})
|
||||
|
||||
assert callback.events == [
|
||||
("event1", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}),
|
||||
("event2", {"x": 1}, UUID("00000000-0000-0000-0000-000000000007"), [], {}),
|
||||
]
|
@ -31,6 +31,8 @@ EXPECTED_ALL = [
|
||||
"StdOutCallbackHandler",
|
||||
"StreamingStdOutCallbackHandler",
|
||||
"FileCallbackHandler",
|
||||
"adispatch_custom_event",
|
||||
"dispatch_custom_event",
|
||||
]
|
||||
|
||||
|
||||
|
@ -2353,3 +2353,245 @@ async def test_cancel_astream_events() -> None:
|
||||
|
||||
# node "anotherwhile" should never start
|
||||
assert anotherwhile.started is False
|
||||
|
||||
|
||||
async def test_custom_event() -> None:
|
||||
"""Test adhoc event."""
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
|
||||
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
||||
# applied as a decorator to async functions.
|
||||
@RunnableLambda # type: ignore[arg-type]
|
||||
async def foo(x: int, config: RunnableConfig) -> int:
|
||||
"""Simple function that emits some adhoc events."""
|
||||
await adispatch_custom_event("event1", {"x": x}, config=config)
|
||||
await adispatch_custom_event("event2", "foo", config=config)
|
||||
return x + 1
|
||||
|
||||
uuid1 = uuid.UUID(int=7)
|
||||
|
||||
events = await _collect_events(
|
||||
foo.astream_events(
|
||||
1,
|
||||
version="v2",
|
||||
config={"run_id": uuid1},
|
||||
),
|
||||
with_nulled_ids=False,
|
||||
)
|
||||
|
||||
run_id = str(uuid1)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": 1},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": [],
|
||||
"run_id": run_id,
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"x": 1},
|
||||
"event": "on_custom_event",
|
||||
"metadata": {},
|
||||
"name": "event1",
|
||||
"parent_ids": [],
|
||||
"run_id": run_id,
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": "foo",
|
||||
"event": "on_custom_event",
|
||||
"metadata": {},
|
||||
"name": "event2",
|
||||
"parent_ids": [],
|
||||
"run_id": run_id,
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": 2},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": [],
|
||||
"run_id": run_id,
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": 2},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": [],
|
||||
"run_id": run_id,
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def test_custom_event_nested() -> None:
|
||||
"""Test adhoc event in a nested chain."""
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
|
||||
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
||||
# applied as a decorator to async functions.
|
||||
@RunnableLambda # type: ignore[arg-type]
|
||||
async def foo(x: int, config: RunnableConfig) -> int:
|
||||
"""Simple function that emits some adhoc events."""
|
||||
await adispatch_custom_event("event1", {"x": x}, config=config)
|
||||
await adispatch_custom_event("event2", "foo", config=config)
|
||||
return x + 1
|
||||
|
||||
run_id = uuid.UUID(int=7)
|
||||
child_run_id = uuid.UUID(int=8)
|
||||
|
||||
# Ignoring type due to RunnableLamdba being dynamic when it comes to being
|
||||
# applied as a decorator to async functions.
|
||||
@RunnableLambda # type: ignore[arg-type]
|
||||
async def bar(x: int, config: RunnableConfig) -> int:
|
||||
"""Simple function that emits some adhoc events."""
|
||||
return await foo.ainvoke(
|
||||
x, # type: ignore[arg-type]
|
||||
{"run_id": child_run_id, **config},
|
||||
)
|
||||
|
||||
events = await _collect_events(
|
||||
bar.astream_events(
|
||||
1,
|
||||
version="v2",
|
||||
config={"run_id": run_id},
|
||||
),
|
||||
with_nulled_ids=False,
|
||||
)
|
||||
|
||||
run_id = str(run_id) # type: ignore[assignment]
|
||||
child_run_id = str(child_run_id) # type: ignore[assignment]
|
||||
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": 1},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "bar",
|
||||
"parent_ids": [],
|
||||
"run_id": "00000000-0000-0000-0000-000000000007",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"input": 1},
|
||||
"event": "on_chain_start",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": ["00000000-0000-0000-0000-000000000007"],
|
||||
"run_id": "00000000-0000-0000-0000-000000000008",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"x": 1},
|
||||
"event": "on_custom_event",
|
||||
"metadata": {},
|
||||
"name": "event1",
|
||||
"parent_ids": ["00000000-0000-0000-0000-000000000007"],
|
||||
"run_id": "00000000-0000-0000-0000-000000000008",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": "foo",
|
||||
"event": "on_custom_event",
|
||||
"metadata": {},
|
||||
"name": "event2",
|
||||
"parent_ids": ["00000000-0000-0000-0000-000000000007"],
|
||||
"run_id": "00000000-0000-0000-0000-000000000008",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"input": 1, "output": 2},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": ["00000000-0000-0000-0000-000000000007"],
|
||||
"run_id": "00000000-0000-0000-0000-000000000008",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"chunk": 2},
|
||||
"event": "on_chain_stream",
|
||||
"metadata": {},
|
||||
"name": "bar",
|
||||
"parent_ids": [],
|
||||
"run_id": "00000000-0000-0000-0000-000000000007",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": 2},
|
||||
"event": "on_chain_end",
|
||||
"metadata": {},
|
||||
"name": "bar",
|
||||
"parent_ids": [],
|
||||
"run_id": "00000000-0000-0000-0000-000000000007",
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def test_custom_event_root_dispatch() -> None:
|
||||
"""Test adhoc event in a nested chain."""
|
||||
# This just tests that nothing breaks on the path.
|
||||
# It shouldn't do anything at the moment, since the tracer isn't configured
|
||||
# to handle adhoc events.
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
|
||||
# Expected behavior is that the event cannot be dispatched
|
||||
with pytest.raises(RuntimeError):
|
||||
await adispatch_custom_event("event1", {"x": 1})
|
||||
|
||||
|
||||
IS_GTE_3_11 = sys.version_info >= (3, 11)
|
||||
|
||||
|
||||
# Test relies on automatically picking up RunnableConfig from contextvars
|
||||
@pytest.mark.skipif(not IS_GTE_3_11, reason="Requires Python >=3.11")
|
||||
async def test_custom_event_root_dispatch_with_in_tool() -> None:
|
||||
"""Test adhoc event in a nested chain."""
|
||||
from langchain_core.callbacks.manager import adispatch_custom_event
|
||||
from langchain_core.tools import tool
|
||||
|
||||
@tool
|
||||
async def foo(x: int) -> int:
|
||||
"""Foo"""
|
||||
await adispatch_custom_event("event1", {"x": x})
|
||||
return x + 1
|
||||
|
||||
# Ignoring type due to @tool not returning correct type annotations
|
||||
events = await _collect_events(
|
||||
foo.astream_events({"x": 2}, version="v2") # type: ignore[attr-defined]
|
||||
)
|
||||
assert events == [
|
||||
{
|
||||
"data": {"input": {"x": 2}},
|
||||
"event": "on_tool_start",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": [],
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"x": 2},
|
||||
"event": "on_custom_event",
|
||||
"metadata": {},
|
||||
"name": "event1",
|
||||
"parent_ids": [],
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
{
|
||||
"data": {"output": 3},
|
||||
"event": "on_tool_end",
|
||||
"metadata": {},
|
||||
"name": "foo",
|
||||
"parent_ids": [],
|
||||
"run_id": "",
|
||||
"tags": [],
|
||||
},
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user