core[patch]: support ValidationError from pydantic v1 in tools (#27194)

This commit is contained in:
Vadym Barda 2024-10-08 10:19:04 -04:00 committed by GitHub
parent 16f5fdb38b
commit 8d27325dbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 8 deletions

View File

@ -35,6 +35,7 @@ from pydantic import (
validate_arguments, validate_arguments,
) )
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from pydantic.v1 import validate_arguments as validate_arguments_v1 from pydantic.v1 import validate_arguments as validate_arguments_v1
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -404,7 +405,7 @@ class ChildTool(BaseTool):
"""Handle the content of the ToolException thrown.""" """Handle the content of the ToolException thrown."""
handle_validation_error: Optional[ handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]] Union[bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]]
] = False ] = False
"""Handle the content of the ValidationError thrown.""" """Handle the content of the ValidationError thrown."""
@ -667,7 +668,7 @@ class ChildTool(BaseTool):
else: else:
content = response content = response
status = "success" status = "success"
except ValidationError as e: except (ValidationError, ValidationErrorV1) as e:
if not self.handle_validation_error: if not self.handle_validation_error:
error_to_raise = e error_to_raise = e
else: else:
@ -819,9 +820,11 @@ def _is_tool_call(x: Any) -> bool:
def _handle_validation_error( def _handle_validation_error(
e: ValidationError, e: Union[ValidationError, ValidationErrorV1],
*, *,
flag: Union[Literal[True], str, Callable[[ValidationError], str]], flag: Union[
Literal[True], str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> str: ) -> str:
if isinstance(flag, bool): if isinstance(flag, bool):
content = "Tool input validation error" content = "Tool input validation error"

View File

@ -22,6 +22,7 @@ from typing import (
import pytest import pytest
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain_core import tools from langchain_core import tools
@ -825,7 +826,7 @@ def test_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
expected = "foo bar" expected = "foo bar"
def handling(e: ValidationError) -> str: def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
return expected return expected
_tool = _MockStructuredTool(handle_validation_error=handling) _tool = _MockStructuredTool(handle_validation_error=handling)
@ -842,7 +843,9 @@ def test_validation_error_handling_callable() -> None:
], ],
) )
def test_validation_error_handling_non_validation_error( def test_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]], handler: Union[
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> None: ) -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
@ -887,7 +890,7 @@ async def test_async_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""
expected = "foo bar" expected = "foo bar"
def handling(e: ValidationError) -> str: def handling(e: Union[ValidationError, ValidationErrorV1]) -> str:
return expected return expected
_tool = _MockStructuredTool(handle_validation_error=handling) _tool = _MockStructuredTool(handle_validation_error=handling)
@ -904,7 +907,9 @@ async def test_async_validation_error_handling_callable() -> None:
], ],
) )
async def test_async_validation_error_handling_non_validation_error( async def test_async_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]], handler: Union[
bool, str, Callable[[Union[ValidationError, ValidationErrorV1]], str]
],
) -> None: ) -> None:
"""Test that validation errors are handled correctly.""" """Test that validation errors are handled correctly."""