mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-12 14:23:58 +00:00
core[patch]: support handle_tool_error=(Exception, ...) tuple
This commit is contained in:
parent
0640cbf2f1
commit
09e330064a
@ -19,8 +19,6 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
get_args,
|
|
||||||
get_origin,
|
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -37,6 +35,7 @@ from pydantic import (
|
|||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||||
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
from pydantic.v1 import validate_arguments as validate_arguments_v1
|
||||||
|
from typing_extensions import get_args, get_origin
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -333,13 +332,13 @@ class BaseTool(RunnableSerializable[Union[str, dict, ToolCall], Any]):
|
|||||||
typehint_mandate = """
|
typehint_mandate = """
|
||||||
class ChildTool(BaseTool):
|
class ChildTool(BaseTool):
|
||||||
...
|
...
|
||||||
args_schema: Type[BaseModel] = SchemaClass
|
args_schema: type[BaseModel] = SchemaClass
|
||||||
..."""
|
..."""
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
msg = (
|
msg = (
|
||||||
f"Tool definition for {name} must include valid type annotations"
|
f"Tool definition for {name} must include valid type annotations"
|
||||||
f" for argument 'args_schema' to behave as expected.\n"
|
f" for argument 'args_schema' to behave as expected.\n"
|
||||||
f"Expected annotation of 'Type[BaseModel]'"
|
f"Expected annotation of 'type[BaseModel]'"
|
||||||
f" but got '{args_schema_type}'.\n"
|
f" but got '{args_schema_type}'.\n"
|
||||||
f"Expected class looks like:\n"
|
f"Expected class looks like:\n"
|
||||||
f"{typehint_mandate}"
|
f"{typehint_mandate}"
|
||||||
@ -399,9 +398,9 @@ class ChildTool(BaseTool):
|
|||||||
You can use these to eg identify a specific instance of a tool with its use case.
|
You can use these to eg identify a specific instance of a tool with its use case.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
handle_tool_error: Optional[Union[bool, str, Callable[[ToolException], str]]] = (
|
handle_tool_error: Union[
|
||||||
False
|
None, bool, str, Callable[..., str], tuple[type[Exception], ...]
|
||||||
)
|
] = False
|
||||||
"""Handle the content of the ToolException thrown."""
|
"""Handle the content of the ToolException thrown."""
|
||||||
|
|
||||||
handle_validation_error: Optional[
|
handle_validation_error: Optional[
|
||||||
@ -668,21 +667,33 @@ class ChildTool(BaseTool):
|
|||||||
else:
|
else:
|
||||||
content = response
|
content = response
|
||||||
status = "success"
|
status = "success"
|
||||||
except (ValidationError, ValidationErrorV1) as e:
|
|
||||||
if not self.handle_validation_error:
|
|
||||||
error_to_raise = e
|
|
||||||
else:
|
|
||||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
|
||||||
status = "error"
|
|
||||||
except ToolException as e:
|
|
||||||
if not self.handle_tool_error:
|
|
||||||
error_to_raise = e
|
|
||||||
else:
|
|
||||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
|
||||||
status = "error"
|
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
error_to_raise = e
|
|
||||||
status = "error"
|
status = "error"
|
||||||
|
# Validation error (args don't match pydantic schema)
|
||||||
|
if isinstance(e, (ValidationError, ValidationErrorV1)):
|
||||||
|
# Unhandled
|
||||||
|
if not self.handle_validation_error:
|
||||||
|
error_to_raise = e
|
||||||
|
# Handled
|
||||||
|
else:
|
||||||
|
content = _handle_validation_error(
|
||||||
|
e, flag=self.handle_validation_error
|
||||||
|
)
|
||||||
|
# Tool error (error raised at tool runtime)
|
||||||
|
else:
|
||||||
|
if isinstance(self.handle_tool_error, tuple):
|
||||||
|
handled_types: tuple = self.handle_tool_error
|
||||||
|
elif callable(self.handle_tool_error):
|
||||||
|
handled_types = _infer_handled_types(self.handle_tool_error)
|
||||||
|
else:
|
||||||
|
handled_types = (ToolException,)
|
||||||
|
|
||||||
|
# Unhandled
|
||||||
|
if not self.handle_tool_error or not isinstance(e, handled_types):
|
||||||
|
error_to_raise = e
|
||||||
|
# Handled
|
||||||
|
else:
|
||||||
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
|
|
||||||
if error_to_raise:
|
if error_to_raise:
|
||||||
run_manager.on_tool_error(error_to_raise)
|
run_manager.on_tool_error(error_to_raise)
|
||||||
@ -785,21 +796,32 @@ class ChildTool(BaseTool):
|
|||||||
else:
|
else:
|
||||||
content = response
|
content = response
|
||||||
status = "success"
|
status = "success"
|
||||||
except ValidationError as e:
|
|
||||||
if not self.handle_validation_error:
|
|
||||||
error_to_raise = e
|
|
||||||
else:
|
|
||||||
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
|
||||||
status = "error"
|
|
||||||
except ToolException as e:
|
|
||||||
if not self.handle_tool_error:
|
|
||||||
error_to_raise = e
|
|
||||||
else:
|
|
||||||
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
|
||||||
status = "error"
|
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
error_to_raise = e
|
|
||||||
status = "error"
|
status = "error"
|
||||||
|
# Validation error (args don't match pydantic schema)
|
||||||
|
if isinstance(e, (ValidationError, ValidationErrorV1)):
|
||||||
|
# Unhandled
|
||||||
|
if not self.handle_validation_error:
|
||||||
|
error_to_raise = e
|
||||||
|
# Handled
|
||||||
|
else:
|
||||||
|
content = _handle_validation_error(
|
||||||
|
e, flag=self.handle_validation_error
|
||||||
|
)
|
||||||
|
# Tool error (error raised at tool runtime)
|
||||||
|
else:
|
||||||
|
if isinstance(self.handle_tool_error, tuple):
|
||||||
|
handled_types: tuple = self.handle_tool_error
|
||||||
|
elif callable(self.handle_tool_error):
|
||||||
|
handled_types = _infer_handled_types(self.handle_tool_error)
|
||||||
|
else:
|
||||||
|
handled_types = (ToolException,)
|
||||||
|
# Unhandled
|
||||||
|
if not self.handle_tool_error or not isinstance(e, handled_types):
|
||||||
|
error_to_raise = e
|
||||||
|
# Handled
|
||||||
|
else:
|
||||||
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
|
|
||||||
if error_to_raise:
|
if error_to_raise:
|
||||||
await run_manager.on_tool_error(error_to_raise)
|
await run_manager.on_tool_error(error_to_raise)
|
||||||
@ -842,11 +864,17 @@ def _handle_validation_error(
|
|||||||
|
|
||||||
|
|
||||||
def _handle_tool_error(
|
def _handle_tool_error(
|
||||||
e: ToolException,
|
e: Exception,
|
||||||
*,
|
*,
|
||||||
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
flag: Union[
|
||||||
|
None,
|
||||||
|
bool,
|
||||||
|
str,
|
||||||
|
Callable[..., str],
|
||||||
|
tuple[type[Exception], ...],
|
||||||
|
],
|
||||||
) -> str:
|
) -> str:
|
||||||
if isinstance(flag, bool):
|
if isinstance(flag, (bool, tuple)):
|
||||||
content = e.args[0] if e.args else "Tool execution error"
|
content = e.args[0] if e.args else "Tool execution error"
|
||||||
elif isinstance(flag, str):
|
elif isinstance(flag, str):
|
||||||
content = flag
|
content = flag
|
||||||
@ -854,7 +882,7 @@ def _handle_tool_error(
|
|||||||
content = flag(e)
|
content = flag(e)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
f"Got unexpected type of `handle_tool_error`. Expected bool, str, tuple, "
|
||||||
f"or callable. Received: {flag}"
|
f"or callable. Received: {flag}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@ -1050,3 +1078,24 @@ class BaseToolkit(BaseModel, ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_tools(self) -> list[BaseTool]:
|
def get_tools(self) -> list[BaseTool]:
|
||||||
"""Get the tools in the toolkit."""
|
"""Get the tools in the toolkit."""
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception]]:
|
||||||
|
sig = inspect.signature(handler)
|
||||||
|
params = list(sig.parameters.values())
|
||||||
|
if params:
|
||||||
|
# If it's a method, the first argument is typically 'self' or 'cls'
|
||||||
|
if params[0].name in ["self", "cls"] and len(params) == 2:
|
||||||
|
first_param = params[1]
|
||||||
|
else:
|
||||||
|
first_param = params[0]
|
||||||
|
|
||||||
|
type_hints = get_type_hints(handler)
|
||||||
|
if first_param.name in type_hints:
|
||||||
|
if get_origin(first_param.annotation) is Union:
|
||||||
|
return tuple(get_args(first_param.annotation))
|
||||||
|
return (type_hints[first_param.name],)
|
||||||
|
|
||||||
|
# If no type information is available, return (ToolException,) for backwards
|
||||||
|
# compatibility.
|
||||||
|
return (ToolException,)
|
||||||
|
@ -18,6 +18,7 @@ def tool(
|
|||||||
response_format: Literal["content", "content_and_artifact"] = "content",
|
response_format: Literal["content", "content_and_artifact"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = True,
|
error_on_invalid_docstring: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Make tools out of functions, can be used with or without arguments.
|
"""Make tools out of functions, can be used with or without arguments.
|
||||||
|
|
||||||
@ -42,6 +43,7 @@ def tool(
|
|||||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
||||||
whether to raise ValueError on invalid Google Style docstrings.
|
whether to raise ValueError on invalid Google Style docstrings.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
kwargs: Additional keyword arguments to pass to BaseTool constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The tool.
|
The tool.
|
||||||
@ -186,6 +188,7 @@ def tool(
|
|||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
parse_docstring=parse_docstring,
|
parse_docstring=parse_docstring,
|
||||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# If someone doesn't want a schema applied, we must treat it as
|
# If someone doesn't want a schema applied, we must treat it as
|
||||||
# a simple string->string function
|
# a simple string->string function
|
||||||
@ -202,6 +205,7 @@ def tool(
|
|||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
coroutine=coroutine,
|
coroutine=coroutine,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _make_tool
|
return _make_tool
|
||||||
|
@ -48,6 +48,7 @@ from langchain_core.tools.base import (
|
|||||||
InjectedToolArg,
|
InjectedToolArg,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
_get_all_basemodel_annotations,
|
_get_all_basemodel_annotations,
|
||||||
|
_infer_handled_types,
|
||||||
_is_message_content_block,
|
_is_message_content_block,
|
||||||
_is_message_content_type,
|
_is_message_content_type,
|
||||||
)
|
)
|
||||||
@ -766,6 +767,97 @@ async def test_async_exception_handling_callable() -> None:
|
|||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_exception_handling_tuple() -> None:
|
||||||
|
@tool(handle_tool_error=(ValueError,))
|
||||||
|
def foo(x: int) -> int:
|
||||||
|
"""X"""
|
||||||
|
msg = "bar"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
actual = foo.invoke({"x": 0})
|
||||||
|
expected = "bar"
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
@tool(handle_tool_error=True)
|
||||||
|
def foo2(x: int) -> int:
|
||||||
|
"""X"""
|
||||||
|
msg = "bar"
|
||||||
|
raise ToolException(msg)
|
||||||
|
|
||||||
|
actual = foo2.invoke({"x": 0})
|
||||||
|
expected = "bar"
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
@tool(handle_tool_error=True)
|
||||||
|
def foo3(x: int) -> int:
|
||||||
|
"""X"""
|
||||||
|
msg = "bar"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
foo3.invoke({"x": 0})
|
||||||
|
|
||||||
|
|
||||||
|
def test_exception_handling_callable_type_hints() -> None:
|
||||||
|
def handle(e: ValueError) -> str:
|
||||||
|
return e.args[0]
|
||||||
|
|
||||||
|
@tool(handle_tool_error=handle)
|
||||||
|
def foo(x: int) -> int:
|
||||||
|
"""X"""
|
||||||
|
msg = "bar"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
actual = foo.invoke({"x": 0})
|
||||||
|
expected = "bar"
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
def handle2(e: Union[ToolException, KeyError]) -> str:
|
||||||
|
return e.args[0]
|
||||||
|
|
||||||
|
@tool(handle_tool_error=handle2)
|
||||||
|
def foo2(x: int) -> int:
|
||||||
|
"""X"""
|
||||||
|
msg = "bar"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
foo2.invoke({"x": 0})
|
||||||
|
|
||||||
|
|
||||||
|
def test__infer_handled_types() -> None:
|
||||||
|
def handle(e): # type: ignore
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def handle2(e: Exception) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def handle3(e: Union[ValueError, ToolException]) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
class Handler:
|
||||||
|
def handle(self, e: ValueError) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
handle4 = Handler().handle
|
||||||
|
|
||||||
|
expected: tuple = (ToolException,)
|
||||||
|
actual = _infer_handled_types(handle)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
expected = (Exception,)
|
||||||
|
actual = _infer_handled_types(handle2)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
expected = (ValueError, ToolException)
|
||||||
|
actual = _infer_handled_types(handle3)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
expected = (ValueError,)
|
||||||
|
actual = _infer_handled_types(handle4)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
async def test_async_exception_handling_non_tool_exception() -> None:
|
async def test_async_exception_handling_non_tool_exception() -> None:
|
||||||
_tool = _FakeExceptionTool(exception=ValueError())
|
_tool = _FakeExceptionTool(exception=ValueError())
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
|
Loading…
Reference in New Issue
Block a user