mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
core[minor]: add validation error handler to BaseTool
(#14007)
- **Description:** add a ValidationError handler as a field of [`BaseTool`](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/tools.py#L101) and add unit tests for the code change. - **Issue:** #12721 #13662 - **Dependencies:** None - **Tag maintainer:** - **Twitter handle:** @hmdev3 - **NOTE:** - I'm wondering if the update of document is required. --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
bdacfafa05
commit
cc17334473
@ -20,6 +20,7 @@ from langchain_core.pydantic_v1 import (
|
|||||||
BaseModel,
|
BaseModel,
|
||||||
Extra,
|
Extra,
|
||||||
Field,
|
Field,
|
||||||
|
ValidationError,
|
||||||
create_model,
|
create_model,
|
||||||
root_validator,
|
root_validator,
|
||||||
validate_arguments,
|
validate_arguments,
|
||||||
@ -169,6 +170,11 @@ class ChildTool(BaseTool):
|
|||||||
] = False
|
] = False
|
||||||
"""Handle the content of the ToolException thrown."""
|
"""Handle the content of the ToolException thrown."""
|
||||||
|
|
||||||
|
handle_validation_error: Optional[
|
||||||
|
Union[bool, str, Callable[[ValidationError], str]]
|
||||||
|
] = False
|
||||||
|
"""Handle the content of the ValidationError thrown."""
|
||||||
|
|
||||||
class Config(Serializable.Config):
|
class Config(Serializable.Config):
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -346,6 +352,21 @@ class ChildTool(BaseTool):
|
|||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else self._run(*tool_args, **tool_kwargs)
|
else self._run(*tool_args, **tool_kwargs)
|
||||||
)
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
if not self.handle_validation_error:
|
||||||
|
raise e
|
||||||
|
elif isinstance(self.handle_validation_error, bool):
|
||||||
|
observation = "Tool input validation error"
|
||||||
|
elif isinstance(self.handle_validation_error, str):
|
||||||
|
observation = self.handle_validation_error
|
||||||
|
elif callable(self.handle_validation_error):
|
||||||
|
observation = self.handle_validation_error(e)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||||
|
f"str or callable. Received: {self.handle_validation_error}"
|
||||||
|
)
|
||||||
|
return observation
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
run_manager.on_tool_error(e)
|
run_manager.on_tool_error(e)
|
||||||
@ -422,6 +443,21 @@ class ChildTool(BaseTool):
|
|||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else await self._arun(*tool_args, **tool_kwargs)
|
else await self._arun(*tool_args, **tool_kwargs)
|
||||||
)
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
if not self.handle_validation_error:
|
||||||
|
raise e
|
||||||
|
elif isinstance(self.handle_validation_error, bool):
|
||||||
|
observation = "Tool input validation error"
|
||||||
|
elif isinstance(self.handle_validation_error, str):
|
||||||
|
observation = self.handle_validation_error
|
||||||
|
elif callable(self.handle_validation_error):
|
||||||
|
observation = self.handle_validation_error(e)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||||
|
f"str or callable. Received: {self.handle_validation_error}"
|
||||||
|
)
|
||||||
|
return observation
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
await run_manager.on_tool_error(e)
|
await run_manager.on_tool_error(e)
|
||||||
|
@ -3,7 +3,7 @@ import json
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
from langchain_core.tools import (
|
from langchain_core.tools import (
|
||||||
BaseTool,
|
BaseTool,
|
||||||
SchemaAnnotationError,
|
SchemaAnnotationError,
|
||||||
@ -620,7 +620,10 @@ def test_exception_handling_str() -> None:
|
|||||||
|
|
||||||
def test_exception_handling_callable() -> None:
|
def test_exception_handling_callable() -> None:
|
||||||
expected = "foo bar"
|
expected = "foo bar"
|
||||||
handling = lambda _: expected # noqa: E731
|
|
||||||
|
def handling(e: ToolException) -> str:
|
||||||
|
return expected # noqa: E731
|
||||||
|
|
||||||
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
||||||
actual = _tool.run({})
|
actual = _tool.run({})
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None:
|
|||||||
|
|
||||||
async def test_async_exception_handling_callable() -> None:
|
async def test_async_exception_handling_callable() -> None:
|
||||||
expected = "foo bar"
|
expected = "foo bar"
|
||||||
handling = lambda _: expected # noqa: E731
|
|
||||||
|
def handling(e: ToolException) -> str:
|
||||||
|
return expected # noqa: E731
|
||||||
|
|
||||||
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
_tool = _FakeExceptionTool(handle_tool_error=handling)
|
||||||
actual = await _tool.arun({})
|
actual = await _tool.arun({})
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
@ -691,3 +697,127 @@ def test_structured_tool_from_function() -> None:
|
|||||||
prefix = "foo(bar: int, baz: str) -> str - "
|
prefix = "foo(bar: int, baz: str) -> str - "
|
||||||
assert foo.__doc__ is not None
|
assert foo.__doc__ is not None
|
||||||
assert structured_tool.description == prefix + foo.__doc__.strip()
|
assert structured_tool.description == prefix + foo.__doc__.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_error_handling_bool() -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
expected = "Tool input validation error"
|
||||||
|
_tool = _MockStructuredTool(handle_validation_error=True)
|
||||||
|
actual = _tool.run({})
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_error_handling_str() -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
expected = "foo bar"
|
||||||
|
_tool = _MockStructuredTool(handle_validation_error=expected)
|
||||||
|
actual = _tool.run({})
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_error_handling_callable() -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
expected = "foo bar"
|
||||||
|
|
||||||
|
def handling(e: ValidationError) -> str:
|
||||||
|
return expected # noqa: E731
|
||||||
|
|
||||||
|
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||||
|
actual = _tool.run({})
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"handler",
|
||||||
|
[
|
||||||
|
True,
|
||||||
|
"foo bar",
|
||||||
|
lambda _: "foo bar",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_validation_error_handling_non_validation_error(
|
||||||
|
handler: Union[bool, str, Callable[[ValidationError], str]]
|
||||||
|
) -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
|
||||||
|
class _RaiseNonValidationErrorTool(BaseTool):
|
||||||
|
name = "raise_non_validation_error_tool"
|
||||||
|
description = "A tool that raises a non-validation error"
|
||||||
|
|
||||||
|
def _parse_input(
|
||||||
|
self,
|
||||||
|
tool_input: Union[str, Dict],
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _run(self) -> str:
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
async def _arun(self) -> str:
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
_tool.run({})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_validation_error_handling_bool() -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
expected = "Tool input validation error"
|
||||||
|
_tool = _MockStructuredTool(handle_validation_error=True)
|
||||||
|
actual = await _tool.arun({})
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_validation_error_handling_str() -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
expected = "foo bar"
|
||||||
|
_tool = _MockStructuredTool(handle_validation_error=expected)
|
||||||
|
actual = await _tool.arun({})
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_validation_error_handling_callable() -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
expected = "foo bar"
|
||||||
|
|
||||||
|
def handling(e: ValidationError) -> str:
|
||||||
|
return expected # noqa: E731
|
||||||
|
|
||||||
|
_tool = _MockStructuredTool(handle_validation_error=handling)
|
||||||
|
actual = await _tool.arun({})
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"handler",
|
||||||
|
[
|
||||||
|
True,
|
||||||
|
"foo bar",
|
||||||
|
lambda _: "foo bar",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_async_validation_error_handling_non_validation_error(
|
||||||
|
handler: Union[bool, str, Callable[[ValidationError], str]]
|
||||||
|
) -> None:
|
||||||
|
"""Test that validation errors are handled correctly."""
|
||||||
|
|
||||||
|
class _RaiseNonValidationErrorTool(BaseTool):
|
||||||
|
name = "raise_non_validation_error_tool"
|
||||||
|
description = "A tool that raises a non-validation error"
|
||||||
|
|
||||||
|
def _parse_input(
|
||||||
|
self,
|
||||||
|
tool_input: Union[str, Dict],
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _run(self) -> str:
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
async def _arun(self) -> str:
|
||||||
|
return "dummy"
|
||||||
|
|
||||||
|
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await _tool.arun({})
|
||||||
|
Loading…
Reference in New Issue
Block a user