From e36e25fe2fb7bc694d560220b9daf5c8ce110e54 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 9 Sep 2025 16:11:12 +0200 Subject: [PATCH] feat(langchain): support PEP604 ( `|` union) in tool node error handlers (#32861) This allows to use PEP604 syntax for `ToolNode` error handlers ```python def error_handler(e: ValueError | ToolException) -> str: return "error" ToolNode(my_tool, handle_tool_errors=error_handler).invoke(...) ``` Without this change, this fails with `AttributeError: 'types.UnionType' object has no attribute '__mro__'` --- .../langchain_v1/langchain/agents/tool_node.py | 3 ++- .../unit_tests/agents/test_react_agent.py | 18 +++++++++++++----- .../tests/unit_tests/agents/test_tool_node.py | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/tool_node.py b/libs/langchain_v1/langchain/agents/tool_node.py index aea11721681..263867bf758 100644 --- a/libs/langchain_v1/langchain/agents/tool_node.py +++ b/libs/langchain_v1/langchain/agents/tool_node.py @@ -40,6 +40,7 @@ import inspect import json from copy import copy, deepcopy from dataclasses import replace +from types import UnionType from typing import ( TYPE_CHECKING, Annotated, @@ -246,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], type_hints = get_type_hints(handler) if first_param.name in type_hints: origin = get_origin(first_param.annotation) - if origin is Union: + if origin in [Union, UnionType]: args = get_args(first_param.annotation) if all(issubclass(arg, Exception) for arg in args): return tuple(args) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py index d6a948f3185..1b2dd9eeb4f 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_react_agent.py @@ -1,5 +1,6 @@ import dataclasses import inspect +from types import UnionType from typing import ( Annotated, Union, @@ -343,16 +344,19 @@ def test__infer_handled_types() -> None: def handle2(e: Exception) -> str: return "" - def handle3(e: Union[ValueError, ToolException]) -> str: + def handle3(e: ValueError | ToolException) -> str: + return "" + + def handle4(e: Union[ValueError, ToolException]) -> str: return "" class Handler: def handle(self, e: ValueError) -> str: return "" - handle4 = Handler().handle + handle5 = Handler().handle - def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str: + def handle6(e: Union[Union[TypeError, ValueError], ToolException]) -> str: return "" expected: tuple = (Exception,) @@ -367,14 +371,18 @@ def test__infer_handled_types() -> None: actual = _infer_handled_types(handle3) assert expected == actual - expected = (ValueError,) + expected = (ValueError, ToolException) actual = _infer_handled_types(handle4) assert expected == actual - expected = (TypeError, ValueError, ToolException) + expected = (ValueError,) actual = _infer_handled_types(handle5) assert expected == actual + expected = (TypeError, ValueError, ToolException) + actual = _infer_handled_types(handle6) + assert expected == actual + with pytest.raises(ValueError): def handler(e: str) -> str: diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py index 8581cf3c7cf..e843788d13a 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node.py @@ -272,7 +272,7 @@ def test_tool_node_error_handling_default_exception() -> None: async def test_tool_node_error_handling() -> None: - def handle_all(e: Union[ValueError, ToolException, ToolInvocationError]): + def handle_all(e: ValueError | ToolException | ToolInvocationError): return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e)) # test catching all exceptions, via: