mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 13:00:34 +00:00
core[patch]: support handle_tool_error=(Exception, ...) tuple
This commit is contained in:
parent
0640cbf2f1
commit
09e330064a
@ -19,8 +19,6 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
@ -37,6 +35,7 @@ from pydantic import (
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.callbacks import (
|
||||
@ -333,13 +332,13 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
||||
typehint_mandate = """
|
||||
class ChildTool(BaseTool):
|
||||
...
|
||||
args_schema: Type[BaseModel] = SchemaClass
|
||||
args_schema: type[BaseModel] = SchemaClass
|
||||
..."""
|
||||
name = cls.__name__
|
||||
msg = (
|
||||
f"Tool definition for {name} must include valid type annotations"
|
||||
f" for argument 'args_schema' to behave as expected.\n"
|
||||
f"Expected annotation of 'Type[BaseModel]'"
|
||||
f"Expected annotation of 'type[BaseModel]'"
|
||||
f" but got '{args_schema_type}'.\n"
|
||||
f"Expected class looks like:\n"
|
||||
f"{typehint_mandate}"
|
||||
@ -399,9 +398,9 @@ class ChildTool(BaseTool):
|
||||
You can use these to eg identify a specific instance of a tool with its use case.
|
||||
"""
|
||||
|
||||
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
|
||||
False
|
||||
)
|
||||
handle_tool_error: Union[
|
||||
None, bool, str, Callable[..., str], tuple[type[Exception], ...]
|
||||
] = False
|
||||
"""Handle the content of the ToolException thrown."""
|
||||
|
||||
handle_validation_error: Optional[
|
||||
@ -668,21 +667,33 @@ class ChildTool(BaseTool):
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except (ValidationError, ValidationErrorV1) as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
# Validation error (args don't match pydantic schema)
|
||||
if isinstance(e, (ValidationError, ValidationErrorV1)):
|
||||
# Unhandled
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
# Handled
|
||||
else:
|
||||
content = _handle_validation_error(
|
||||
e, flag=self.handle_validation_error
|
||||
)
|
||||
# Tool error (error raised at tool runtime)
|
||||
else:
|
||||
if isinstance(self.handle_tool_error, tuple):
|
||||
handled_types: tuple = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
handled_types = _infer_handled_types(self.handle_tool_error)
|
||||
else:
|
||||
handled_types = (ToolException,)
|
||||
|
||||
# Unhandled
|
||||
if not self.handle_tool_error or not isinstance(e, handled_types):
|
||||
error_to_raise = e
|
||||
# Handled
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
|
||||
if error_to_raise:
|
||||
run_manager.on_tool_error(error_to_raise)
|
||||
@ -785,21 +796,32 @@ class ChildTool(BaseTool):
|
||||
else:
|
||||
content = response
|
||||
status = "success"
|
||||
except ValidationError as e:
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||
status = "error"
|
||||
except ToolException as e:
|
||||
if not self.handle_tool_error:
|
||||
error_to_raise = e
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
status = "error"
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
error_to_raise = e
|
||||
status = "error"
|
||||
# Validation error (args don't match pydantic schema)
|
||||
if isinstance(e, (ValidationError, ValidationErrorV1)):
|
||||
# Unhandled
|
||||
if not self.handle_validation_error:
|
||||
error_to_raise = e
|
||||
# Handled
|
||||
else:
|
||||
content = _handle_validation_error(
|
||||
e, flag=self.handle_validation_error
|
||||
)
|
||||
# Tool error (error raised at tool runtime)
|
||||
else:
|
||||
if isinstance(self.handle_tool_error, tuple):
|
||||
handled_types: tuple = self.handle_tool_error
|
||||
elif callable(self.handle_tool_error):
|
||||
handled_types = _infer_handled_types(self.handle_tool_error)
|
||||
else:
|
||||
handled_types = (ToolException,)
|
||||
# Unhandled
|
||||
if not self.handle_tool_error or not isinstance(e, handled_types):
|
||||
error_to_raise = e
|
||||
# Handled
|
||||
else:
|
||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||
|
||||
if error_to_raise:
|
||||
await run_manager.on_tool_error(error_to_raise)
|
||||
@ -842,11 +864,17 @@ def _handle_validation_error(
|
||||
|
||||
|
||||
def _handle_tool_error(
|
||||
e: ToolException,
|
||||
e: Exception,
|
||||
*,
|
||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||
flag: Union[
|
||||
None,
|
||||
bool,
|
||||
str,
|
||||
Callable[..., str],
|
||||
tuple[type[Exception], ...],
|
||||
],
|
||||
) -> str:
|
||||
if isinstance(flag, bool):
|
||||
if isinstance(flag, (bool, tuple)):
|
||||
content = e.args[0] if e.args else "Tool execution error"
|
||||
elif isinstance(flag, str):
|
||||
content = flag
|
||||
@ -854,7 +882,7 @@ def _handle_tool_error(
|
||||
content = flag(e)
|
||||
else:
|
||||
msg = (
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str, tuple, "
|
||||
f"or callable. Received: {flag}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
@ -1050,3 +1078,24 @@ class BaseToolkit(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[BaseTool]:
|
||||
"""Get the tools in the toolkit."""
|
||||
|
||||
|
||||
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
if params:
|
||||
# If it's a method, the first argument is typically 'self' or 'cls'
|
||||
if params[0].name in ["self", "cls"] and len(params) == 2:
|
||||
first_param = params[1]
|
||||
else:
|
||||
first_param = params[0]
|
||||
|
||||
type_hints = get_type_hints(handler)
|
||||
if first_param.name in type_hints:
|
||||
if get_origin(first_param.annotation) is Union:
|
||||
return tuple(get_args(first_param.annotation))
|
||||
return (type_hints[first_param.name],)
|
||||
|
||||
# If no type information is available, return (ToolException,) for backwards
|
||||
# compatibility.
|
||||
return (ToolException,)
|
||||
|
@ -18,6 +18,7 @@ def tool(
|
||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||
parse_docstring: bool = False,
|
||||
error_on_invalid_docstring: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Callable:
|
||||
"""Make tools out of functions, can be used with or without arguments.
|
||||
|
||||
@ -42,6 +43,7 @@ def tool(
|
||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
||||
whether to raise ValueError on invalid Google Style docstrings.
|
||||
Defaults to True.
|
||||
kwargs: Additional keyword arguments to pass to BaseTool constructor.
|
||||
|
||||
Returns:
|
||||
The tool.
|
||||
@ -186,6 +188,7 @@ def tool(
|
||||
response_format=response_format,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
**kwargs,
|
||||
)
|
||||
# If someone doesn't want a schema applied, we must treat it as
|
||||
# a simple string->string function
|
||||
@ -202,6 +205,7 @@ def tool(
|
||||
return_direct=return_direct,
|
||||
coroutine=coroutine,
|
||||
response_format=response_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
|
@ -48,6 +48,7 @@ from langchain_core.tools.base import (
|
||||
InjectedToolArg,
|
||||
SchemaAnnotationError,
|
||||
_get_all_basemodel_annotations,
|
||||
_infer_handled_types,
|
||||
_is_message_content_block,
|
||||
_is_message_content_type,
|
||||
)
|
||||
@ -766,6 +767,97 @@ async def test_async_exception_handling_callable() -> None:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_exception_handling_tuple() -> None:
|
||||
@tool(handle_tool_error=(ValueError,))
|
||||
def foo(x: int) -> int:
|
||||
"""X"""
|
||||
msg = "bar"
|
||||
raise ValueError(msg)
|
||||
|
||||
actual = foo.invoke({"x": 0})
|
||||
expected = "bar"
|
||||
assert actual == expected
|
||||
|
||||
@tool(handle_tool_error=True)
|
||||
def foo2(x: int) -> int:
|
||||
"""X"""
|
||||
msg = "bar"
|
||||
raise ToolException(msg)
|
||||
|
||||
actual = foo2.invoke({"x": 0})
|
||||
expected = "bar"
|
||||
assert actual == expected
|
||||
|
||||
@tool(handle_tool_error=True)
|
||||
def foo3(x: int) -> int:
|
||||
"""X"""
|
||||
msg = "bar"
|
||||
raise ValueError(msg)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
foo3.invoke({"x": 0})
|
||||
|
||||
|
||||
def test_exception_handling_callable_type_hints() -> None:
|
||||
def handle(e: ValueError) -> str:
|
||||
return e.args[0]
|
||||
|
||||
@tool(handle_tool_error=handle)
|
||||
def foo(x: int) -> int:
|
||||
"""X"""
|
||||
msg = "bar"
|
||||
raise ValueError(msg)
|
||||
|
||||
actual = foo.invoke({"x": 0})
|
||||
expected = "bar"
|
||||
assert actual == expected
|
||||
|
||||
def handle2(e: Union[ToolException, KeyError]) -> str:
|
||||
return e.args[0]
|
||||
|
||||
@tool(handle_tool_error=handle2)
|
||||
def foo2(x: int) -> int:
|
||||
"""X"""
|
||||
msg = "bar"
|
||||
raise ValueError(msg)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
foo2.invoke({"x": 0})
|
||||
|
||||
|
||||
def test__infer_handled_types() -> None:
|
||||
def handle(e): # type: ignore
|
||||
return ""
|
||||
|
||||
def handle2(e: Exception) -> str:
|
||||
return ""
|
||||
|
||||
def handle3(e: Union[ValueError, ToolException]) -> str:
|
||||
return ""
|
||||
|
||||
class Handler:
|
||||
def handle(self, e: ValueError) -> str:
|
||||
return ""
|
||||
|
||||
handle4 = Handler().handle
|
||||
|
||||
expected: tuple = (ToolException,)
|
||||
actual = _infer_handled_types(handle)
|
||||
assert expected == actual
|
||||
|
||||
expected = (Exception,)
|
||||
actual = _infer_handled_types(handle2)
|
||||
assert expected == actual
|
||||
|
||||
expected = (ValueError, ToolException)
|
||||
actual = _infer_handled_types(handle3)
|
||||
assert expected == actual
|
||||
|
||||
expected = (ValueError,)
|
||||
actual = _infer_handled_types(handle4)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||
_tool = _FakeExceptionTool(exception=ValueError())
|
||||
with pytest.raises(ValueError):
|
||||
|
Loading…
Reference in New Issue
Block a user