core[patch]: add on_tool_error to _AstreamEventsCallbackHandler

This commit is contained in:
vbarda 2025-04-07 10:46:29 -04:00
parent 4f9f97bd12
commit 98780a98c6
2 changed files with 40 additions and 0 deletions

View File

@ -23,6 +23,11 @@ class EventData(TypedDict, total=False):
won't be known until the *END* of the Runnable when it has finished streaming won't be known until the *END* of the Runnable when it has finished streaming
its inputs. its inputs.
""" """
error: NotRequired[BaseException]
"""The error that occurred during the execution of the Runnable.
This field is only available if the Runnable raised an exception.
"""
output: Any output: Any
"""The output of the Runnable that generated the event. """The output of the Runnable that generated the event.

View File

@ -627,6 +627,41 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"tool", "tool",
) )
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
run_info = self.run_map.pop(run_id)
if "inputs" not in run_info:
msg = (
f"Run ID {run_id} is a tool call and is expected to have "
f"inputs associated with it."
)
raise AssertionError(msg)
inputs = run_info["inputs"]
self._send(
{
"event": "on_tool_error",
"data": {
"error": error,
"input": inputs,
},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
"parent_ids": self._get_parent_ids(run_id),
},
"tool",
)
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run.""" """End a trace for a tool run."""
run_info = self.run_map.pop(run_id) run_info = self.run_map.pop(run_id)