catch exceptions properly

This commit is contained in:
isaac hershenson
2024-09-11 11:59:14 -07:00
parent c50cd99d99
commit 1865e737d0
2 changed files with 94 additions and 86 deletions

View File

@@ -239,15 +239,6 @@ class ToolException(Exception):
pass pass
def convert_exception_to_tool_exception(e: Exception) -> ToolException:
tool_e = ToolException()
tool_e.args = e.args
tool_e.__cause__ = e.__cause__
tool_e.__context__ = e.__context__
tool_e.__traceback__ = e.__traceback__
return tool_e
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]): class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
"""Interface LangChain tools must implement.""" """Interface LangChain tools must implement."""
@@ -552,44 +543,54 @@ class ChildTool(BaseTool):
artifact = None artifact = None
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try: try:
child_config = patch_config(config, callbacks=run_manager.get_child()) try:
context = copy_context() child_config = patch_config(config, callbacks=run_manager.get_child())
context.run(_set_config_context, child_config) context = copy_context()
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) context.run(_set_config_context, child_config)
if signature(self._run).parameters.get("run_manager"): tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
tool_kwargs["run_manager"] = run_manager if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run): if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs) response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact": if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2: if not isinstance(response, tuple) or len(response) != 2:
raise ValueError( raise ValueError(
"Since response_format='content_and_artifact' " "Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is " "a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: " f"expected. Instead generated response of type: "
f"{type(response)}." f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
) )
content, artifact = response status = "error"
else: except ToolException as e:
content = response if not self.handle_tool_error:
status = "success" error_to_raise = e
except ValidationError as e: else:
if not self.handle_validation_error: content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except Exception as e:
raise ToolException(str(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e error_to_raise = e
else: status = "error"
content = _handle_validation_error(e, flag=self.handle_validation_error) except ToolException as e:
status = "error"
except Exception as e:
if not self.handle_tool_error: if not self.handle_tool_error:
error_to_raise = e error_to_raise = e
else: else:
e = convert_exception_to_tool_exception(e)
content = _handle_tool_error(e, flag=self.handle_tool_error) content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error" status = "error"
except KeyboardInterrupt as e:
error_to_raise = e
status = "error"
if error_to_raise: if error_to_raise:
run_manager.on_tool_error(error_to_raise) run_manager.on_tool_error(error_to_raise)
@@ -662,51 +663,61 @@ class ChildTool(BaseTool):
artifact = None artifact = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try: try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) try:
child_config = patch_config(config, callbacks=run_manager.get_child()) tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
context = copy_context() child_config = patch_config(config, callbacks=run_manager.get_child())
context.run(_set_config_context, child_config) context = copy_context()
func_to_check = ( context.run(_set_config_context, child_config)
self._run if self.__class__._arun is BaseTool._arun else self._arun func_to_check = (
) self._run if self.__class__._arun is BaseTool._arun else self._arun
if signature(func_to_check).parameters.get("run_manager"): )
tool_kwargs["run_manager"] = run_manager if signature(func_to_check).parameters.get("run_manager"):
if config_param := _get_runnable_config_param(func_to_check): tool_kwargs["run_manager"] = run_manager
tool_kwargs[config_param] = config if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs) coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context(): if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore response = await asyncio.create_task(coro, context=context) # type: ignore
else: else:
response = await coro response = await coro
if self.response_format == "content_and_artifact": if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2: if not isinstance(response, tuple) or len(response) != 2:
raise ValueError( raise ValueError(
"Since response_format='content_and_artifact' " "Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is " "a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: " f"expected. Instead generated response of type: "
f"{type(response)}." f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
) )
content, artifact = response status = "error"
else: except ToolException as e:
content = response if not self.handle_tool_error:
status = "success" error_to_raise = e
except ValidationError as e: else:
if not self.handle_validation_error: content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except Exception as e:
raise ToolException(str(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e error_to_raise = e
else: status = "error"
content = _handle_validation_error(e, flag=self.handle_validation_error) except ToolException as e:
status = "error"
except Exception as e:
if not self.handle_tool_error: if not self.handle_tool_error:
error_to_raise = e error_to_raise = e
else: else:
e = convert_exception_to_tool_exception(e)
content = _handle_tool_error(e, flag=self.handle_tool_error) content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error" status = "error"
except KeyboardInterrupt as e:
error_to_raise = e
status = "error"
if error_to_raise: if error_to_raise:
await run_manager.on_tool_error(error_to_raise) await run_manager.on_tool_error(error_to_raise)

View File

@@ -679,7 +679,7 @@ def test_exception_handling_callable() -> None:
def test_exception_handling_non_tool_exception() -> None: def test_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError()) _tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError): with pytest.raises(ToolException):
_tool.run({}) _tool.run({})
@@ -710,7 +710,7 @@ async def test_async_exception_handling_callable() -> None:
async def test_async_exception_handling_non_tool_exception() -> None: async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError()) _tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError): with pytest.raises(ToolException):
await _tool.arun({}) await _tool.arun({})
@@ -806,7 +806,7 @@ def test_validation_error_handling_non_validation_error(
return "dummy" return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg] _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError): with pytest.raises(ToolException):
_tool.run({}) _tool.run({})
@@ -868,7 +868,7 @@ async def test_async_validation_error_handling_non_validation_error(
return "dummy" return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg] _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError): with pytest.raises(ToolException):
await _tool.arun({}) await _tool.arun({})
@@ -1408,10 +1408,8 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert tool_.invoke( assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo") ) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError with pytest.raises(ToolException):
)
with pytest.raises(expected_error):
tool_.invoke({"x": 5}) tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == { assert convert_to_openai_function(tool_) == {
@@ -1873,7 +1871,6 @@ def test__get_all_basemodel_annotations_v1() -> None:
def test_tool_args_schema_pydantic_v2_with_metadata() -> None: def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic import Field as FieldV2 # pydantic: ignore from pydantic import Field as FieldV2 # pydantic: ignore
from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore
class Foo(BaseModelV2): class Foo(BaseModelV2):
x: List[int] = FieldV2( x: List[int] = FieldV2(
@@ -1903,7 +1900,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
} }
assert foo.invoke({"x": [0] * 10}) assert foo.invoke({"x": [0] * 10})
with pytest.raises(ValidationErrorV2): with pytest.raises(ToolException):
foo.invoke({"x": [0] * 9}) foo.invoke({"x": [0] * 9})