diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 4de0452020f..423ada27bc1 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -19,8 +19,6 @@ from typing import ( TypeVar, Union, cast, - get_args, - get_origin, get_type_hints, ) @@ -37,6 +35,7 @@ from pydantic import ( from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import validate_arguments as validate_arguments_v1 +from typing_extensions import get_args, get_origin from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -333,13 +332,13 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]): typehint_mandate = """ class ChildTool(BaseTool): ... - args_schema: Type[BaseModel] = SchemaClass + args_schema: type[BaseModel] = SchemaClass ...""" name = cls.__name__ msg = ( f"Tool definition for {name} must include valid type annotations" f" for argument 'args_schema' to behave as expected.\n" - f"Expected annotation of 'Type[BaseModel]'" + f"Expected annotation of 'type[BaseModel]'" f" but got '{args_schema_type}'.\n" f"Expected class looks like:\n" f"{typehint_mandate}" @@ -399,9 +398,9 @@ class ChildTool(BaseTool): You can use these to eg identify a specific instance of a tool with its use case. """ - handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = ( - False - ) + handle_tool_error: Union[ + None, bool, str, Callable[..., str], tuple[type[Exception], ...] + ] = False """Handle the content of the ToolException thrown.""" handle_validation_error: Optional[ @@ -668,21 +667,33 @@ class ChildTool(BaseTool): else: content = response status = "success" - except (ValidationError, ValidationErrorV1) as e: - if not self.handle_validation_error: - error_to_raise = e - else: - content = _handle_validation_error(e, flag=self.handle_validation_error) - status = "error" - except ToolException as e: - if not self.handle_tool_error: - error_to_raise = e - else: - content = _handle_tool_error(e, flag=self.handle_tool_error) - status = "error" except (Exception, KeyboardInterrupt) as e: - error_to_raise = e status = "error" + # Validation error (args don't match pydantic schema) + if isinstance(e, (ValidationError, ValidationErrorV1)): + # Unhandled + if not self.handle_validation_error: + error_to_raise = e + # Handled + else: + content = _handle_validation_error( + e, flag=self.handle_validation_error + ) + # Tool error (error raised at tool runtime) + else: + if isinstance(self.handle_tool_error, tuple): + handled_types: tuple = self.handle_tool_error + elif callable(self.handle_tool_error): + handled_types = _infer_handled_types(self.handle_tool_error) + else: + handled_types = (ToolException,) + + # Unhandled + if not self.handle_tool_error or not isinstance(e, handled_types): + error_to_raise = e + # Handled + else: + content = _handle_tool_error(e, flag=self.handle_tool_error) if error_to_raise: run_manager.on_tool_error(error_to_raise) @@ -785,21 +796,32 @@ class ChildTool(BaseTool): else: content = response status = "success" - except ValidationError as e: - if not self.handle_validation_error: - error_to_raise = e - else: - content = _handle_validation_error(e, flag=self.handle_validation_error) - status = "error" - except ToolException as e: - if not self.handle_tool_error: - error_to_raise = e - else: - content = _handle_tool_error(e, flag=self.handle_tool_error) - status = "error" except (Exception, KeyboardInterrupt) as e: - error_to_raise = e status = "error" + # Validation error (args don't match pydantic schema) + if isinstance(e, (ValidationError, ValidationErrorV1)): + # Unhandled + if not self.handle_validation_error: + error_to_raise = e + # Handled + else: + content = _handle_validation_error( + e, flag=self.handle_validation_error + ) + # Tool error (error raised at tool runtime) + else: + if isinstance(self.handle_tool_error, tuple): + handled_types: tuple = self.handle_tool_error + elif callable(self.handle_tool_error): + handled_types = _infer_handled_types(self.handle_tool_error) + else: + handled_types = (ToolException,) + # Unhandled + if not self.handle_tool_error or not isinstance(e, handled_types): + error_to_raise = e + # Handled + else: + content = _handle_tool_error(e, flag=self.handle_tool_error) if error_to_raise: await run_manager.on_tool_error(error_to_raise) @@ -842,11 +864,17 @@ def _handle_validation_error( def _handle_tool_error( - e: ToolException, + e: Exception, *, - flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], + flag: Union[ + None, + bool, + str, + Callable[..., str], + tuple[type[Exception], ...], + ], ) -> str: - if isinstance(flag, bool): + if isinstance(flag, (bool, tuple)): content = e.args[0] if e.args else "Tool execution error" elif isinstance(flag, str): content = flag @@ -854,7 +882,7 @@ def _handle_tool_error( content = flag(e) else: msg = ( - f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"Got unexpected type of `handle_tool_error`. Expected bool, str, tuple, " f"or callable. Received: {flag}" ) raise ValueError(msg) @@ -1050,3 +1078,24 @@ class BaseToolkit(BaseModel, ABC): @abstractmethod def get_tools(self) -> list[BaseTool]: """Get the tools in the toolkit.""" + + +def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]: + sig = inspect.signature(handler) + params = list(sig.parameters.values()) + if params: + # If it's a method, the first argument is typically 'self' or 'cls' + if params[0].name in ["self", "cls"] and len(params) == 2: + first_param = params[1] + else: + first_param = params[0] + + type_hints = get_type_hints(handler) + if first_param.name in type_hints: + if get_origin(first_param.annotation) is Union: + return tuple(get_args(first_param.annotation)) + return (type_hints[first_param.name],) + + # If no type information is available, return (ToolException,) for backwards + # compatibility. + return (ToolException,) diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index e85435a86df..4b3dd305c84 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -18,6 +18,7 @@ def tool( response_format: Literal["content", "content_and_artifact"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, + **kwargs: Any, ) -> Callable: """Make tools out of functions, can be used with or without arguments. @@ -42,6 +43,7 @@ def tool( error_on_invalid_docstring: if ``parse_docstring`` is provided, configure whether to raise ValueError on invalid Google Style docstrings. Defaults to True. + kwargs: Additional keyword arguments to pass to BaseTool constructor. Returns: The tool. @@ -186,6 +188,7 @@ def tool( response_format=response_format, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, + **kwargs, ) # If someone doesn't want a schema applied, we must treat it as # a simple string->string function @@ -202,6 +205,7 @@ def tool( return_direct=return_direct, coroutine=coroutine, response_format=response_format, + **kwargs, ) return _make_tool diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 065e36c8668..07648a5eb77 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -48,6 +48,7 @@ from langchain_core.tools.base import ( InjectedToolArg, SchemaAnnotationError, _get_all_basemodel_annotations, + _infer_handled_types, _is_message_content_block, _is_message_content_type, ) @@ -766,6 +767,97 @@ async def test_async_exception_handling_callable() -> None: assert expected == actual +def test_exception_handling_tuple() -> None: + @tool(handle_tool_error=(ValueError,)) + def foo(x: int) -> int: + """X""" + msg = "bar" + raise ValueError(msg) + + actual = foo.invoke({"x": 0}) + expected = "bar" + assert actual == expected + + @tool(handle_tool_error=True) + def foo2(x: int) -> int: + """X""" + msg = "bar" + raise ToolException(msg) + + actual = foo2.invoke({"x": 0}) + expected = "bar" + assert actual == expected + + @tool(handle_tool_error=True) + def foo3(x: int) -> int: + """X""" + msg = "bar" + raise ValueError(msg) + + with pytest.raises(ValueError): + foo3.invoke({"x": 0}) + + +def test_exception_handling_callable_type_hints() -> None: + def handle(e: ValueError) -> str: + return e.args[0] + + @tool(handle_tool_error=handle) + def foo(x: int) -> int: + """X""" + msg = "bar" + raise ValueError(msg) + + actual = foo.invoke({"x": 0}) + expected = "bar" + assert actual == expected + + def handle2(e: Union[ToolException, KeyError]) -> str: + return e.args[0] + + @tool(handle_tool_error=handle2) + def foo2(x: int) -> int: + """X""" + msg = "bar" + raise ValueError(msg) + + with pytest.raises(ValueError): + foo2.invoke({"x": 0}) + + +def test__infer_handled_types() -> None: + def handle(e): # type: ignore + return "" + + def handle2(e: Exception) -> str: + return "" + + def handle3(e: Union[ValueError, ToolException]) -> str: + return "" + + class Handler: + def handle(self, e: ValueError) -> str: + return "" + + handle4 = Handler().handle + + expected: tuple = (ToolException,) + actual = _infer_handled_types(handle) + assert expected == actual + + expected = (Exception,) + actual = _infer_handled_types(handle2) + assert expected == actual + + expected = (ValueError, ToolException) + actual = _infer_handled_types(handle3) + assert expected == actual + + expected = (ValueError,) + actual = _infer_handled_types(handle4) + assert expected == actual + + async def test_async_exception_handling_non_tool_exception() -> None: _tool = _FakeExceptionTool(exception=ValueError()) with pytest.raises(ValueError):