diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 6436a4bf945..63d74e33c3e 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -239,15 +239,6 @@ class ToolException(Exception): pass -def convert_exception_to_tool_exception(e: Exception) -> ToolException: - tool_e = ToolException() - tool_e.args = e.args - tool_e.__cause__ = e.__cause__ - tool_e.__context__ = e.__context__ - tool_e.__traceback__ = e.__traceback__ - return tool_e - - class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]): """Interface LangChain tools must implement.""" @@ -552,44 +543,54 @@ class ChildTool(BaseTool): artifact = None error_to_raise: Union[Exception, KeyboardInterrupt, None] = None try: - child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) - if signature(self._run).parameters.get("run_manager"): - tool_kwargs["run_manager"] = run_manager + try: + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(_set_config_context, child_config) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + if signature(self._run).parameters.get("run_manager"): + tool_kwargs["run_manager"] = run_manager - if config_param := _get_runnable_config_param(self._run): - tool_kwargs[config_param] = config - response = context.run(self._run, *tool_args, **tool_kwargs) - if self.response_format == "content_and_artifact": - if not isinstance(response, tuple) or len(response) != 2: - raise ValueError( - "Since response_format='content_and_artifact' " - "a two-tuple of the message content and raw tool output is " - f"expected. Instead generated response of type: " - f"{type(response)}." + if config_param := _get_runnable_config_param(self._run): + tool_kwargs[config_param] = config + response = context.run(self._run, *tool_args, **tool_kwargs) + if self.response_format == "content_and_artifact": + if not isinstance(response, tuple) or len(response) != 2: + raise ValueError( + "Since response_format='content_and_artifact' " + "a two-tuple of the message content and raw tool output is " + f"expected. Instead generated response of type: " + f"{type(response)}." + ) + content, artifact = response + else: + content = response + status = "success" + except ValidationError as e: + if not self.handle_validation_error: + error_to_raise = e + else: + content = _handle_validation_error( + e, flag=self.handle_validation_error ) - content, artifact = response - else: - content = response - status = "success" - except ValidationError as e: - if not self.handle_validation_error: + status = "error" + except ToolException as e: + if not self.handle_tool_error: + error_to_raise = e + else: + content = _handle_tool_error(e, flag=self.handle_tool_error) + status = "error" + except Exception as e: + raise ToolException(str(e)) from e + except KeyboardInterrupt as e: error_to_raise = e - else: - content = _handle_validation_error(e, flag=self.handle_validation_error) - status = "error" - except Exception as e: + status = "error" + except ToolException as e: if not self.handle_tool_error: error_to_raise = e else: - e = convert_exception_to_tool_exception(e) content = _handle_tool_error(e, flag=self.handle_tool_error) status = "error" - except KeyboardInterrupt as e: - error_to_raise = e - status = "error" if error_to_raise: run_manager.on_tool_error(error_to_raise) @@ -662,51 +663,61 @@ class ChildTool(BaseTool): artifact = None error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None try: - tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) - child_config = patch_config(config, callbacks=run_manager.get_child()) - context = copy_context() - context.run(_set_config_context, child_config) - func_to_check = ( - self._run if self.__class__._arun is BaseTool._arun else self._arun - ) - if signature(func_to_check).parameters.get("run_manager"): - tool_kwargs["run_manager"] = run_manager - if config_param := _get_runnable_config_param(func_to_check): - tool_kwargs[config_param] = config + try: + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + child_config = patch_config(config, callbacks=run_manager.get_child()) + context = copy_context() + context.run(_set_config_context, child_config) + func_to_check = ( + self._run if self.__class__._arun is BaseTool._arun else self._arun + ) + if signature(func_to_check).parameters.get("run_manager"): + tool_kwargs["run_manager"] = run_manager + if config_param := _get_runnable_config_param(func_to_check): + tool_kwargs[config_param] = config - coro = context.run(self._arun, *tool_args, **tool_kwargs) - if asyncio_accepts_context(): - response = await asyncio.create_task(coro, context=context) # type: ignore - else: - response = await coro - if self.response_format == "content_and_artifact": - if not isinstance(response, tuple) or len(response) != 2: - raise ValueError( - "Since response_format='content_and_artifact' " - "a two-tuple of the message content and raw tool output is " - f"expected. Instead generated response of type: " - f"{type(response)}." + coro = context.run(self._arun, *tool_args, **tool_kwargs) + if asyncio_accepts_context(): + response = await asyncio.create_task(coro, context=context) # type: ignore + else: + response = await coro + if self.response_format == "content_and_artifact": + if not isinstance(response, tuple) or len(response) != 2: + raise ValueError( + "Since response_format='content_and_artifact' " + "a two-tuple of the message content and raw tool output is " + f"expected. Instead generated response of type: " + f"{type(response)}." + ) + content, artifact = response + else: + content = response + status = "success" + except ValidationError as e: + if not self.handle_validation_error: + error_to_raise = e + else: + content = _handle_validation_error( + e, flag=self.handle_validation_error ) - content, artifact = response - else: - content = response - status = "success" - except ValidationError as e: - if not self.handle_validation_error: + status = "error" + except ToolException as e: + if not self.handle_tool_error: + error_to_raise = e + else: + content = _handle_tool_error(e, flag=self.handle_tool_error) + status = "error" + except Exception as e: + raise ToolException(str(e)) from e + except KeyboardInterrupt as e: error_to_raise = e - else: - content = _handle_validation_error(e, flag=self.handle_validation_error) - status = "error" - except Exception as e: + status = "error" + except ToolException as e: if not self.handle_tool_error: error_to_raise = e else: - e = convert_exception_to_tool_exception(e) content = _handle_tool_error(e, flag=self.handle_tool_error) status = "error" - except KeyboardInterrupt as e: - error_to_raise = e - status = "error" if error_to_raise: await run_manager.on_tool_error(error_to_raise) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 39cf811fc5e..746bbedcc2a 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -679,7 +679,7 @@ def test_exception_handling_callable() -> None: def test_exception_handling_non_tool_exception() -> None: _tool = _FakeExceptionTool(exception=ValueError()) - with pytest.raises(ValueError): + with pytest.raises(ToolException): _tool.run({}) @@ -710,7 +710,7 @@ async def test_async_exception_handling_callable() -> None: async def test_async_exception_handling_non_tool_exception() -> None: _tool = _FakeExceptionTool(exception=ValueError()) - with pytest.raises(ValueError): + with pytest.raises(ToolException): await _tool.arun({}) @@ -806,7 +806,7 @@ def test_validation_error_handling_non_validation_error( return "dummy" _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg] - with pytest.raises(NotImplementedError): + with pytest.raises(ToolException): _tool.run({}) @@ -868,7 +868,7 @@ async def test_async_validation_error_handling_non_validation_error( return "dummy" _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg] - with pytest.raises(NotImplementedError): + with pytest.raises(ToolException): await _tool.arun({}) @@ -1408,10 +1408,8 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None: assert tool_.invoke( {"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"} ) == ToolMessage("bar", tool_call_id="123", name="foo") - expected_error = ( - ValidationError if not isinstance(tool_, InjectedTool) else TypeError - ) - with pytest.raises(expected_error): + + with pytest.raises(ToolException): tool_.invoke({"x": 5}) assert convert_to_openai_function(tool_) == { @@ -1873,7 +1871,6 @@ def test__get_all_basemodel_annotations_v1() -> None: def test_tool_args_schema_pydantic_v2_with_metadata() -> None: from pydantic import BaseModel as BaseModelV2 # pydantic: ignore from pydantic import Field as FieldV2 # pydantic: ignore - from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore class Foo(BaseModelV2): x: List[int] = FieldV2( @@ -1903,7 +1900,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None: } assert foo.invoke({"x": [0] * 10}) - with pytest.raises(ValidationErrorV2): + with pytest.raises(ToolException): foo.invoke({"x": [0] * 9})