From 98780a98c6e316fb22435ff5ae0efbcb077a83fa Mon Sep 17 00:00:00 2001 From: vbarda Date: Mon, 7 Apr 2025 10:46:29 -0400 Subject: [PATCH] core[patch]: add on_tool_error to _AstreamEventsCallbackHandler --- libs/core/langchain_core/runnables/schema.py | 5 +++ .../langchain_core/tracers/event_stream.py | 35 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/libs/core/langchain_core/runnables/schema.py b/libs/core/langchain_core/runnables/schema.py index 20ad580070a..e49f3c74949 100644 --- a/libs/core/langchain_core/runnables/schema.py +++ b/libs/core/langchain_core/runnables/schema.py @@ -23,6 +23,11 @@ class EventData(TypedDict, total=False): won't be known until the *END* of the Runnable when it has finished streaming its inputs. """ + error: NotRequired[BaseException] + """The error that occurred during the execution of the Runnable. + + This field is only available if the Runnable raised an exception. + """ output: Any """The output of the Runnable that generated the event. diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 7106a281ebd..b50b93b709f 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -627,6 +627,41 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand "tool", ) + async def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + run_info = self.run_map.pop(run_id) + if "inputs" not in run_info: + msg = ( + f"Run ID {run_id} is a tool call and is expected to have " + f"inputs associated with it." + ) + raise AssertionError(msg) + inputs = run_info["inputs"] + + self._send( + { + "event": "on_tool_error", + "data": { + "error": error, + "input": inputs, + }, + "run_id": str(run_id), + "name": run_info["name"], + "tags": run_info["tags"], + "metadata": run_info["metadata"], + "parent_ids": self._get_parent_ids(run_id), + }, + "tool", + ) + async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for a tool run.""" run_info = self.run_map.pop(run_id)