mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-29 21:30:18 +00:00
core[minor]: add validation error handler to BaseTool (#14007)
- **Description:** add a ValidationError handler as a field of [`BaseTool`](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/tools.py#L101) and add unit tests for the code change. - **Issue:** #12721 #13662 - **Dependencies:** None - **Tag maintainer:** - **Twitter handle:** @hmdev3 - **NOTE:** - I'm wondering if the update of document is required. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -11,7 +11,7 @@ from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
SchemaAnnotationError,
|
||||
@@ -620,7 +620,10 @@ def test_exception_handling_str() -> None:
|
||||
|
||||
def test_exception_handling_callable() -> None:
|
||||
expected = "foo bar"
|
||||
handling = lambda _: expected # noqa: E731
|
||||
|
||||
def handling(e: ToolException) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
@@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None:
|
||||
|
||||
async def test_async_exception_handling_callable() -> None:
|
||||
expected = "foo bar"
|
||||
handling = lambda _: expected # noqa: E731
|
||||
|
||||
def handling(e: ToolException) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
@@ -691,3 +697,127 @@ def test_structured_tool_from_function() -> None:
|
||||
prefix = "foo(bar: int, baz: str) -> str - "
|
||||
assert foo.__doc__ is not None
|
||||
assert structured_tool.description == prefix + foo.__doc__.strip()
|
||||
|
||||
|
||||
def test_validation_error_handling_bool() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "Tool input validation error"
|
||||
_tool = _MockStructuredTool(handle_validation_error=True)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_validation_error_handling_str() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
_tool = _MockStructuredTool(handle_validation_error=expected)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_validation_error_handling_callable() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
|
||||
def handling(e: ValidationError) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||
actual = _tool.run({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"handler",
|
||||
[
|
||||
True,
|
||||
"foo bar",
|
||||
lambda _: "foo bar",
|
||||
],
|
||||
)
|
||||
def test_validation_error_handling_non_validation_error(
|
||||
handler: Union[bool, str, Callable[[ValidationError], str]]
|
||||
) -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
|
||||
class _RaiseNonValidationErrorTool(BaseTool):
|
||||
name = "raise_non_validation_error_tool"
|
||||
description = "A tool that raises a non-validation error"
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _run(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
async def _arun(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||
with pytest.raises(NotImplementedError):
|
||||
_tool.run({})
|
||||
|
||||
|
||||
async def test_async_validation_error_handling_bool() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "Tool input validation error"
|
||||
_tool = _MockStructuredTool(handle_validation_error=True)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
async def test_async_validation_error_handling_str() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
_tool = _MockStructuredTool(handle_validation_error=expected)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
async def test_async_validation_error_handling_callable() -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
expected = "foo bar"
|
||||
|
||||
def handling(e: ValidationError) -> str:
|
||||
return expected # noqa: E731
|
||||
|
||||
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||
actual = await _tool.arun({})
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"handler",
|
||||
[
|
||||
True,
|
||||
"foo bar",
|
||||
lambda _: "foo bar",
|
||||
],
|
||||
)
|
||||
async def test_async_validation_error_handling_non_validation_error(
|
||||
handler: Union[bool, str, Callable[[ValidationError], str]]
|
||||
) -> None:
|
||||
"""Test that validation errors are handled correctly."""
|
||||
|
||||
class _RaiseNonValidationErrorTool(BaseTool):
|
||||
name = "raise_non_validation_error_tool"
|
||||
description = "A tool that raises a non-validation error"
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _run(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
async def _arun(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||
with pytest.raises(NotImplementedError):
|
||||
await _tool.arun({})
|
||||
|
||||
Reference in New Issue
Block a user