core[patch]: support handle_tool_error=(Exception, ...) tuple

This commit is contained in:
Bagatur 2024-10-22 09:35:52 -07:00
parent 0640cbf2f1
commit 09e330064a
3 changed files with 182 additions and 37 deletions

View File

@ -19,8 +19,6 @@ from typing import (
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)
@ -37,6 +35,7 @@ from pydantic import (
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1
from typing_extensions import get_args, get_origin
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -333,13 +332,13 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
typehint_mandate = """
class ChildTool(BaseTool):
...
args_schema: Type[BaseModel] = SchemaClass
args_schema: type[BaseModel] = SchemaClass
..."""
name = cls.__name__
msg = (
f"Tool definition for {name} must include valid type annotations"
f" for argument 'args_schema' to behave as expected.\n"
f"Expected annotation of 'Type[BaseModel]'"
f"Expected annotation of 'type[BaseModel]'"
f" but got '{args_schema_type}'.\n"
f"Expected class looks like:\n"
f"{typehint_mandate}"
@ -399,9 +398,9 @@ class ChildTool(BaseTool):
You can use these to eg identify a specific instance of a tool with its use case.
"""
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
False
)
handle_tool_error: Union[
None, bool, str, Callable[..., str], tuple[type[Exception], ...]
] = False
"""Handle the content of the ToolException thrown."""
handle_validation_error: Optional[
@ -668,21 +667,33 @@ class ChildTool(BaseTool):
else:
content = response
status = "success"
except (ValidationError, ValidationErrorV1) as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
# Validation error (args don't match pydantic schema)
if isinstance(e, (ValidationError, ValidationErrorV1)):
# Unhandled
if not self.handle_validation_error:
error_to_raise = e
# Handled
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
# Tool error (error raised at tool runtime)
else:
if isinstance(self.handle_tool_error, tuple):
handled_types: tuple = self.handle_tool_error
elif callable(self.handle_tool_error):
handled_types = _infer_handled_types(self.handle_tool_error)
else:
handled_types = (ToolException,)
# Unhandled
if not self.handle_tool_error or not isinstance(e, handled_types):
error_to_raise = e
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
@ -785,21 +796,32 @@ class ChildTool(BaseTool):
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
# Validation error (args don't match pydantic schema)
if isinstance(e, (ValidationError, ValidationErrorV1)):
# Unhandled
if not self.handle_validation_error:
error_to_raise = e
# Handled
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
# Tool error (error raised at tool runtime)
else:
if isinstance(self.handle_tool_error, tuple):
handled_types: tuple = self.handle_tool_error
elif callable(self.handle_tool_error):
handled_types = _infer_handled_types(self.handle_tool_error)
else:
handled_types = (ToolException,)
# Unhandled
if not self.handle_tool_error or not isinstance(e, handled_types):
error_to_raise = e
# Handled
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
if error_to_raise:
await run_manager.on_tool_error(error_to_raise)
@ -842,11 +864,17 @@ def _handle_validation_error(
def _handle_tool_error(
e: ToolException,
e: Exception,
*,
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
flag: Union[
None,
bool,
str,
Callable[..., str],
tuple[type[Exception], ...],
],
) -> str:
if isinstance(flag, bool):
if isinstance(flag, (bool, tuple)):
content = e.args[0] if e.args else "Tool execution error"
elif isinstance(flag, str):
content = flag
@ -854,7 +882,7 @@ def _handle_tool_error(
content = flag(e)
else:
msg = (
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
f"Got unexpected type of `handle_tool_error`. Expected bool, str, tuple, "
f"or callable. Received: {flag}"
)
raise ValueError(msg)
@ -1050,3 +1078,24 @@ class BaseToolkit(BaseModel, ABC):
@abstractmethod
def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit."""
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
if params:
# If it's a method, the first argument is typically 'self' or 'cls'
if params[0].name in ["self", "cls"] and len(params) == 2:
first_param = params[1]
else:
first_param = params[0]
type_hints = get_type_hints(handler)
if first_param.name in type_hints:
if get_origin(first_param.annotation) is Union:
return tuple(get_args(first_param.annotation))
return (type_hints[first_param.name],)
# If no type information is available, return (ToolException,) for backwards
# compatibility.
return (ToolException,)

View File

@ -18,6 +18,7 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
**kwargs: Any,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@ -42,6 +43,7 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
kwargs: Additional keyword arguments to pass to BaseTool constructor.
Returns:
The tool.
@ -186,6 +188,7 @@ def tool(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
**kwargs,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
@ -202,6 +205,7 @@ def tool(
return_direct=return_direct,
coroutine=coroutine,
response_format=response_format,
**kwargs,
)
return _make_tool

View File

@ -48,6 +48,7 @@ from langchain_core.tools.base import (
InjectedToolArg,
SchemaAnnotationError,
_get_all_basemodel_annotations,
_infer_handled_types,
_is_message_content_block,
_is_message_content_type,
)
@ -766,6 +767,97 @@ async def test_async_exception_handling_callable() -> None:
assert expected == actual
def test_exception_handling_tuple() -> None:
@tool(handle_tool_error=(ValueError,))
def foo(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
actual = foo.invoke({"x": 0})
expected = "bar"
assert actual == expected
@tool(handle_tool_error=True)
def foo2(x: int) -> int:
"""X"""
msg = "bar"
raise ToolException(msg)
actual = foo2.invoke({"x": 0})
expected = "bar"
assert actual == expected
@tool(handle_tool_error=True)
def foo3(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
with pytest.raises(ValueError):
foo3.invoke({"x": 0})
def test_exception_handling_callable_type_hints() -> None:
def handle(e: ValueError) -> str:
return e.args[0]
@tool(handle_tool_error=handle)
def foo(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
actual = foo.invoke({"x": 0})
expected = "bar"
assert actual == expected
def handle2(e: Union[ToolException, KeyError]) -> str:
return e.args[0]
@tool(handle_tool_error=handle2)
def foo2(x: int) -> int:
"""X"""
msg = "bar"
raise ValueError(msg)
with pytest.raises(ValueError):
foo2.invoke({"x": 0})
def test__infer_handled_types() -> None:
def handle(e): # type: ignore
return ""
def handle2(e: Exception) -> str:
return ""
def handle3(e: Union[ValueError, ToolException]) -> str:
return ""
class Handler:
def handle(self, e: ValueError) -> str:
return ""
handle4 = Handler().handle
expected: tuple = (ToolException,)
actual = _infer_handled_types(handle)
assert expected == actual
expected = (Exception,)
actual = _infer_handled_types(handle2)
assert expected == actual
expected = (ValueError, ToolException)
actual = _infer_handled_types(handle3)
assert expected == actual
expected = (ValueError,)
actual = _infer_handled_types(handle4)
assert expected == actual
async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):