This commit is contained in:
isaac hershenson
2024-09-11 11:29:06 -07:00
parent ee932c19d8
commit c50cd99d99

View File

@@ -239,6 +239,15 @@ class ToolException(Exception):
pass pass
def convert_exception_to_tool_exception(e: Exception) -> ToolException:
tool_e = ToolException()
tool_e.args = e.args
tool_e.__cause__ = e.__cause__
tool_e.__context__ = e.__context__
tool_e.__traceback__ = e.__traceback__
return tool_e
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]): class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
@@ -575,6 +584,7 @@ class ChildTool(BaseTool):
if not self.handle_tool_error: if not self.handle_tool_error:
error_to_raise = e error_to_raise = e
else: else:
e = convert_exception_to_tool_exception(e)
content = _handle_tool_error(e, flag=self.handle_tool_error) content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error" status = "error"
except KeyboardInterrupt as e: except KeyboardInterrupt as e:
@@ -691,6 +701,7 @@ class ChildTool(BaseTool):
if not self.handle_tool_error: if not self.handle_tool_error:
error_to_raise = e error_to_raise = e
else: else:
e = convert_exception_to_tool_exception(e)
content = _handle_tool_error(e, flag=self.handle_tool_error) content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error" status = "error"
except KeyboardInterrupt as e: except KeyboardInterrupt as e:
@@ -735,7 +746,7 @@ def _handle_validation_error(
def _handle_tool_error( def _handle_tool_error(
e: Exception, e: ToolException,
*, *,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str: ) -> str: