core[patch]: support handle_tool_error=(Exception, ...) tuple

This commit is contained in:
Bagatur 2024-10-22 09:35:52 -07:00
parent 0640cbf2f1
commit 09e330064a
3 changed files with 182 additions and 37 deletions

View File

@ -19,8 +19,6 @@ from typing import (
TypeVar, TypeVar,
Union, Union,
cast, cast,
get_args,
get_origin,
get_type_hints, get_type_hints,
) )
@ -37,6 +35,7 @@ from pydantic import (
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1 from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import get_args, get_origin
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -333,13 +332,13 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
typehint_mandate = """ typehint_mandate = """
class ChildTool(BaseTool): class ChildTool(BaseTool):
... ...
args_schema: Type[BaseModel] = SchemaClass args_schema: type[BaseModel] = SchemaClass
...""" ..."""
name = cls.__name__ name = cls.__name__
msg = ( msg = (
f"Tool definition for {name} must include valid type annotations" f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n" f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'" f"Expected annotation of 'type[BaseModel]'"
f" but got '{args_schema_type}'.\n" f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n" f"Expected class looks like:\n"
f"{typehint_mandate}" f"{typehint_mandate}"
@ -399,9 +398,9 @@ class ChildTool(BaseTool):
You can use these to eg identify a specific instance of a tool with its use case. You can use these to eg identify a specific instance of a tool with its use case.
""" """
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = ( handle_tool_error: Union[
False None, bool, str, Callable[..., str], tuple[type[Exception], ...]
) ] = False
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
handle_validation_error: Optional[ handle_validation_error: Optional[
@ -668,21 +667,33 @@ class ChildTool(BaseTool):
else: else:
content = response content = response
status = "success" status = "success"
except (ValidationError, ValidationErrorV1) as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error" status = "error"
# Validation error (args don't match pydantic schema)
if isinstance(e, (ValidationError, ValidationErrorV1)):
# Unhandled
if not self.handle_validation_error:
error_to_raise = e
# Handled
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
# Tool error (error raised at tool runtime)
else:
if isinstance(self.handle_tool_error, tuple):
handled_types: tuple = self.handle_tool_error
elif callable(self.handle_tool_error):
handled_types = _infer_handled_types(self.handle_tool_error)
else:
handled_types = (ToolException,)
# Unhandled
if not self.handle_tool_error or not isinstance(e, handled_types):
error_to_raise = e
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
if error_to_raise: if error_to_raise:
run_manager.on_tool_error(error_to_raise) run_manager.on_tool_error(error_to_raise)
@ -785,21 +796,32 @@ class ChildTool(BaseTool):
else: else:
content = response content = response
status = "success" status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error" status = "error"
# Validation error (args don't match pydantic schema)
if isinstance(e, (ValidationError, ValidationErrorV1)):
# Unhandled
if not self.handle_validation_error:
error_to_raise = e
# Handled
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
# Tool error (error raised at tool runtime)
else:
if isinstance(self.handle_tool_error, tuple):
handled_types: tuple = self.handle_tool_error
elif callable(self.handle_tool_error):
handled_types = _infer_handled_types(self.handle_tool_error)
else:
handled_types = (ToolException,)
# Unhandled
if not self.handle_tool_error or not isinstance(e, handled_types):
error_to_raise = e
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
if error_to_raise: if error_to_raise:
await run_manager.on_tool_error(error_to_raise) await run_manager.on_tool_error(error_to_raise)
@ -842,11 +864,17 @@ def _handle_validation_error(
def _handle_tool_error( def _handle_tool_error(
e: ToolException, e: Exception,
*, *,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], flag: Union[
None,
bool,
str,
Callable[..., str],
tuple[type[Exception], ...],
],
) -> str: ) -> str:
if isinstance(flag, bool): if isinstance(flag, (bool, tuple)):
content = e.args[0] if e.args else "Tool execution error" content = e.args[0] if e.args else "Tool execution error"
elif isinstance(flag, str): elif isinstance(flag, str):
content = flag content = flag
@ -854,7 +882,7 @@ def _handle_tool_error(
content = flag(e) content = flag(e)
else: else:
msg = ( msg = (
f"Got unexpected type of `handle_tool_error`. Expected bool, str " f"Got unexpected type of `handle_tool_error`. Expected bool, str, tuple, "
f"or callable. Received: {flag}" f"or callable. Received: {flag}"
) )
raise ValueError(msg) raise ValueError(msg)
@ -1050,3 +1078,24 @@ class BaseToolkit(BaseModel, ABC):
@abstractmethod @abstractmethod
def get_tools(self) -> list[BaseTool]: def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
if params:
# If it's a method, the first argument is typically 'self' or 'cls'
if params[0].name in ["self", "cls"] and len(params) == 2:
first_param = params[1]
else:
first_param = params[0]
type_hints = get_type_hints(handler)
if first_param.name in type_hints:
if get_origin(first_param.annotation) is Union:
return tuple(get_args(first_param.annotation))
return (type_hints[first_param.name],)
# If no type information is available, return (ToolException,) for backwards
# compatibility.
return (ToolException,)

View File

@ -18,6 +18,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content", response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False, parse_docstring: bool = False,
error_on_invalid_docstring: bool = True, error_on_invalid_docstring: bool = True,
**kwargs: Any,
) -> Callable: ) -> Callable:
"""Make tools out of functions, can be used with or without arguments. """Make tools out of functions, can be used with or without arguments.
@ -42,6 +43,7 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings. whether to raise ValueError on invalid Google Style docstrings.
Defaults to True. Defaults to True.
kwargs: Additional keyword arguments to pass to BaseTool constructor.
Returns: Returns:
The tool. The tool.
@ -186,6 +188,7 @@ def tool(
response_format=response_format, response_format=response_format,
parse_docstring=parse_docstring, parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring, error_on_invalid_docstring=error_on_invalid_docstring,
**kwargs,
) )
# If someone doesn't want a schema applied, we must treat it as # If someone doesn't want a schema applied, we must treat it as
# a simple string->string function # a simple string->string function
@ -202,6 +205,7 @@ def tool(
return_direct=return_direct, return_direct=return_direct,
coroutine=coroutine, coroutine=coroutine,
response_format=response_format, response_format=response_format,
**kwargs,
) )
return _make_tool return _make_tool

View File

@ -48,6 +48,7 @@ from langchain_core.tools.base import (
InjectedToolArg, InjectedToolArg,
SchemaAnnotationError, SchemaAnnotationError,
_get_all_basemodel_annotations, _get_all_basemodel_annotations,
_infer_handled_types,
_is_message_content_block, _is_message_content_block,
_is_message_content_type, _is_message_content_type,
) )
@ -766,6 +767,97 @@ async def test_async_exception_handling_callable() -> None:
assert expected == actual assert expected == actual
def test_exception_handling_tuple() -> None:
@tool(handle_tool_error=(ValueError,))
def foo(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
actual = foo.invoke({"x": 0})
expected = "bar"
assert actual == expected
@tool(handle_tool_error=True)
def foo2(x: int) -> int:
"""X"""
msg = "bar"
raise ToolException(msg)
actual = foo2.invoke({"x": 0})
expected = "bar"
assert actual == expected
@tool(handle_tool_error=True)
def foo3(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
with pytest.raises(ValueError):
foo3.invoke({"x": 0})
def test_exception_handling_callable_type_hints() -> None:
def handle(e: ValueError) -> str:
return e.args[0]
@tool(handle_tool_error=handle)
def foo(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
actual = foo.invoke({"x": 0})
expected = "bar"
assert actual == expected
def handle2(e: Union[ToolException, KeyError]) -> str:
return e.args[0]
@tool(handle_tool_error=handle2)
def foo2(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
with pytest.raises(ValueError):
foo2.invoke({"x": 0})
def test__infer_handled_types() -> None:
def handle(e): # type: ignore
return ""
def handle2(e: Exception) -> str:
return ""
def handle3(e: Union[ValueError, ToolException]) -> str:
return ""
class Handler:
def handle(self, e: ValueError) -> str:
return ""
handle4 = Handler().handle
expected: tuple = (ToolException,)
actual = _infer_handled_types(handle)
assert expected == actual
expected = (Exception,)
actual = _infer_handled_types(handle2)
assert expected == actual
expected = (ValueError, ToolException)
actual = _infer_handled_types(handle3)
assert expected == actual
expected = (ValueError,)
actual = _infer_handled_types(handle4)
assert expected == actual
async def test_async_exception_handling_non_tool_exception() -> None: async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError()) _tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError): with pytest.raises(ValueError):