fix(core): add on_tool_error to _AstreamEventsCallbackHandler (#30709)

Fixes https://github.com/langchain-ai/langchain/issues/30708

---------

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Vadym Barda
2025-09-10 21:34:05 -04:00
committed by GitHub
parent ced9fc270f
commit 8b1e25461b
2 changed files with 59 additions and 8 deletions

View File

@@ -23,6 +23,13 @@ class EventData(TypedDict, total=False):
won't be known until the *END* of the Runnable when it has finished streaming
its inputs.
"""
error: NotRequired[BaseException]
"""The error that occurred during the execution of the Runnable.
This field is only available if the Runnable raised an exception.
.. versionadded:: 1.0.0
"""
output: Any
"""The output of the Runnable that generated the event.

View File

@@ -610,6 +610,28 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
run_type,
)
def _get_tool_run_info_with_inputs(self, run_id: UUID) -> tuple[RunInfo, Any]:
"""Get run info for a tool and extract inputs, with validation.
Args:
run_id: The run ID of the tool.
Returns:
A tuple of (run_info, inputs).
Raises:
AssertionError: If the run ID is a tool call and does not have inputs.
"""
run_info = self.run_map.pop(run_id)
if "inputs" not in run_info:
msg = (
f"Run ID {run_id} is a tool call and is expected to have "
f"inputs associated with it."
)
raise AssertionError(msg)
inputs = run_info["inputs"]
return run_info, inputs
@override
async def on_tool_start(
self,
@@ -652,6 +674,35 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
"tool",
)
@override
async def on_tool_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
run_info, inputs = self._get_tool_run_info_with_inputs(run_id)
self._send(
{
"event": "on_tool_error",
"data": {
"error": error,
"input": inputs,
},
"run_id": str(run_id),
"name": run_info["name"],
"tags": run_info["tags"],
"metadata": run_info["metadata"],
"parent_ids": self._get_parent_ids(run_id),
},
"tool",
)
@override
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run.
@@ -659,14 +710,7 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
Raises:
AssertionError: If the run ID is a tool call and does not have inputs
"""
run_info = self.run_map.pop(run_id)
if "inputs" not in run_info:
msg = (
f"Run ID {run_id} is a tool call and is expected to have "
f"inputs associated with it."
)
raise AssertionError(msg)
inputs = run_info["inputs"]
run_info, inputs = self._get_tool_run_info_with_inputs(run_id)
self._send(
{