feat(langchain): support PEP604 ( | union) in tool node error handlers (#32861)

This allows to use PEP604 syntax for `ToolNode` error handlers
```python
def error_handler(e: ValueError | ToolException) -> str:
    return "error"

ToolNode(my_tool, handle_tool_errors=error_handler).invoke(...)
```
Without this change, this fails with `AttributeError: 'types.UnionType'
object has no attribute '__mro__'`
This commit is contained in:
Christophe Bornet
2025-09-09 16:11:12 +02:00
committed by GitHub
parent cc3b5afe52
commit e36e25fe2f
3 changed files with 16 additions and 7 deletions

View File

@@ -40,6 +40,7 @@ import inspect
import json
from copy import copy, deepcopy
from dataclasses import replace
from types import UnionType
from typing import (
TYPE_CHECKING,
Annotated,
@@ -246,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception],
type_hints = get_type_hints(handler)
if first_param.name in type_hints:
origin = get_origin(first_param.annotation)
if origin is Union:
if origin in [Union, UnionType]:
args = get_args(first_param.annotation)
if all(issubclass(arg, Exception) for arg in args):
return tuple(args)

View File

@@ -1,5 +1,6 @@
import dataclasses
import inspect
from types import UnionType
from typing import (
Annotated,
Union,
@@ -343,16 +344,19 @@ def test__infer_handled_types() -> None:
def handle2(e: Exception) -> str:
return ""
def handle3(e: Union[ValueError, ToolException]) -> str:
def handle3(e: ValueError | ToolException) -> str:
return ""
def handle4(e: Union[ValueError, ToolException]) -> str:
return ""
class Handler:
def handle(self, e: ValueError) -> str:
return ""
handle4 = Handler().handle
handle5 = Handler().handle
def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
def handle6(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
return ""
expected: tuple = (Exception,)
@@ -367,14 +371,18 @@ def test__infer_handled_types() -> None:
actual = _infer_handled_types(handle3)
assert expected == actual
expected = (ValueError,)
expected = (ValueError, ToolException)
actual = _infer_handled_types(handle4)
assert expected == actual
expected = (TypeError, ValueError, ToolException)
expected = (ValueError,)
actual = _infer_handled_types(handle5)
assert expected == actual
expected = (TypeError, ValueError, ToolException)
actual = _infer_handled_types(handle6)
assert expected == actual
with pytest.raises(ValueError):
def handler(e: str) -> str:

View File

@@ -272,7 +272,7 @@ def test_tool_node_error_handling_default_exception() -> None:
async def test_tool_node_error_handling() -> None:
def handle_all(e: Union[ValueError, ToolException, ToolInvocationError]):
def handle_all(e: ValueError | ToolException | ToolInvocationError):
return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
# test catching all exceptions, via: