diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index 88c787558ff..4de0452020f 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -35,6 +35,7 @@ from pydantic import ( validate_arguments, ) from pydantic.v1 import BaseModel as BaseModelV1 +from pydantic.v1 import ValidationError as ValidationErrorV1 from pydantic.v1 import validate_arguments as validate_arguments_v1 from langchain_core._api import deprecated @@ -404,7 +405,7 @@ class ChildTool(BaseTool): """Handle the content of the ToolException thrown.""" handle_validation_error: Optional[ - Union[bool, str, Callable[[ValidationError], str]] + Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]] ] = False """Handle the content of the ValidationError thrown.""" @@ -667,7 +668,7 @@ class ChildTool(BaseTool): else: content = response status = "success" - except ValidationError as e: + except (ValidationError, ValidationErrorV1) as e: if not self.handle_validation_error: error_to_raise = e else: @@ -819,9 +820,11 @@ def _is_tool_call(x: Any) -> bool: def _handle_validation_error( - e: ValidationError, + e: Union[ValidationError, ValidationErrorV1], *, - flag: Union[Literal[True], str, Callable[[ValidationError], str]], + flag: Union[ + Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str] + ], ) -> str: if isinstance(flag, bool): content = "Tool input validation error" diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index f77b473d39f..065e36c8668 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -22,6 +22,7 @@ from typing import ( import pytest from pydantic import BaseModel, Field, ValidationError from pydantic.v1 import BaseModel as BaseModelV1 +from pydantic.v1 import ValidationError as ValidationErrorV1 from typing_extensions import TypedDict from langchain_core import tools @@ -825,7 +826,7 @@ def test_validation_error_handling_callable() -> None: """Test that validation errors are handled correctly.""" expected = "foo bar" - def handling(e: ValidationError) -> str: + def handling(e: Union[ValidationError, ValidationErrorV1]) -> str: return expected _tool = _MockStructuredTool(handle_validation_error=handling) @@ -842,7 +843,9 @@ def test_validation_error_handling_callable() -> None: ], ) def test_validation_error_handling_non_validation_error( - handler: Union[bool, str, Callable[[ValidationError], str]], + handler: Union[ + bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str] + ], ) -> None: """Test that validation errors are handled correctly.""" @@ -887,7 +890,7 @@ async def test_async_validation_error_handling_callable() -> None: """Test that validation errors are handled correctly.""" expected = "foo bar" - def handling(e: ValidationError) -> str: + def handling(e: Union[ValidationError, ValidationErrorV1]) -> str: return expected _tool = _MockStructuredTool(handle_validation_error=handling) @@ -904,7 +907,9 @@ async def test_async_validation_error_handling_callable() -> None: ], ) async def test_async_validation_error_handling_non_validation_error( - handler: Union[bool, str, Callable[[ValidationError], str]], + handler: Union[ + bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str] + ], ) -> None: """Test that validation errors are handled correctly."""