mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 03:31:51 +00:00
core[patch]: add on_tool_error to _AstreamEventsCallbackHandler
This commit is contained in:
parent
4f9f97bd12
commit
98780a98c6
@ -23,6 +23,11 @@ class EventData(TypedDict, total=False):
|
||||
won't be known until the *END* of the Runnable when it has finished streaming
|
||||
its inputs.
|
||||
"""
|
||||
error: NotRequired[BaseException]
|
||||
"""The error that occurred during the execution of the Runnable.
|
||||
|
||||
This field is only available if the Runnable raised an exception.
|
||||
"""
|
||||
output: Any
|
||||
"""The output of the Runnable that generated the event.
|
||||
|
||||
|
@ -627,6 +627,41 @@ class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHand
|
||||
"tool",
|
||||
)
|
||||
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
run_info = self.run_map.pop(run_id)
|
||||
if "inputs" not in run_info:
|
||||
msg = (
|
||||
f"Run ID {run_id} is a tool call and is expected to have "
|
||||
f"inputs associated with it."
|
||||
)
|
||||
raise AssertionError(msg)
|
||||
inputs = run_info["inputs"]
|
||||
|
||||
self._send(
|
||||
{
|
||||
"event": "on_tool_error",
|
||||
"data": {
|
||||
"error": error,
|
||||
"input": inputs,
|
||||
},
|
||||
"run_id": str(run_id),
|
||||
"name": run_info["name"],
|
||||
"tags": run_info["tags"],
|
||||
"metadata": run_info["metadata"],
|
||||
"parent_ids": self._get_parent_ids(run_id),
|
||||
},
|
||||
"tool",
|
||||
)
|
||||
|
||||
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
"""End a trace for a tool run."""
|
||||
run_info = self.run_map.pop(run_id)
|
||||
|
Loading…
Reference in New Issue
Block a user