mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 14:05:37 +00:00
feat(langchain): support PEP604 ( |
union) in tool node error handlers (#32861)
This allows to use PEP604 syntax for `ToolNode` error handlers ```python def error_handler(e: ValueError | ToolException) -> str: return "error" ToolNode(my_tool, handle_tool_errors=error_handler).invoke(...) ``` Without this change, this fails with `AttributeError: 'types.UnionType' object has no attribute '__mro__'`
This commit is contained in:
committed by
GitHub
parent
cc3b5afe52
commit
e36e25fe2f
@@ -40,6 +40,7 @@ import inspect
|
||||
import json
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import replace
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -246,7 +247,7 @@ def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception],
|
||||
type_hints = get_type_hints(handler)
|
||||
if first_param.name in type_hints:
|
||||
origin = get_origin(first_param.annotation)
|
||||
if origin is Union:
|
||||
if origin in [Union, UnionType]:
|
||||
args = get_args(first_param.annotation)
|
||||
if all(issubclass(arg, Exception) for arg in args):
|
||||
return tuple(args)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Union,
|
||||
@@ -343,16 +344,19 @@ def test__infer_handled_types() -> None:
|
||||
def handle2(e: Exception) -> str:
|
||||
return ""
|
||||
|
||||
def handle3(e: Union[ValueError, ToolException]) -> str:
|
||||
def handle3(e: ValueError | ToolException) -> str:
|
||||
return ""
|
||||
|
||||
def handle4(e: Union[ValueError, ToolException]) -> str:
|
||||
return ""
|
||||
|
||||
class Handler:
|
||||
def handle(self, e: ValueError) -> str:
|
||||
return ""
|
||||
|
||||
handle4 = Handler().handle
|
||||
handle5 = Handler().handle
|
||||
|
||||
def handle5(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
|
||||
def handle6(e: Union[Union[TypeError, ValueError], ToolException]) -> str:
|
||||
return ""
|
||||
|
||||
expected: tuple = (Exception,)
|
||||
@@ -367,14 +371,18 @@ def test__infer_handled_types() -> None:
|
||||
actual = _infer_handled_types(handle3)
|
||||
assert expected == actual
|
||||
|
||||
expected = (ValueError,)
|
||||
expected = (ValueError, ToolException)
|
||||
actual = _infer_handled_types(handle4)
|
||||
assert expected == actual
|
||||
|
||||
expected = (TypeError, ValueError, ToolException)
|
||||
expected = (ValueError,)
|
||||
actual = _infer_handled_types(handle5)
|
||||
assert expected == actual
|
||||
|
||||
expected = (TypeError, ValueError, ToolException)
|
||||
actual = _infer_handled_types(handle6)
|
||||
assert expected == actual
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
def handler(e: str) -> str:
|
||||
|
@@ -272,7 +272,7 @@ def test_tool_node_error_handling_default_exception() -> None:
|
||||
|
||||
|
||||
async def test_tool_node_error_handling() -> None:
|
||||
def handle_all(e: Union[ValueError, ToolException, ToolInvocationError]):
|
||||
def handle_all(e: ValueError | ToolException | ToolInvocationError):
|
||||
return TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
|
||||
|
||||
# test catching all exceptions, via:
|
||||
|
Reference in New Issue
Block a user