Compare commits

...

6 Commits

Author SHA1 Message Date
isaac hershenson
1865e737d0 catch exceptions properly 2024-09-11 11:59:14 -07:00
isaac hershenson
c50cd99d99 fmt 2024-09-11 11:29:06 -07:00
isaac hershenson
ee932c19d8 type errors 2024-09-11 10:01:37 -07:00
isaac hershenson
cd282e3386 remove print 2024-09-11 09:58:33 -07:00
isaac hershenson
5329b43bc1 auto infer toolexception type 2024-09-11 09:56:31 -07:00
isaac hershenson
ecd5a4b460 tool error handling 2024-09-11 09:43:56 -07:00
3 changed files with 112 additions and 76 deletions

View File

@@ -543,43 +543,54 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except Exception as e:
raise ToolException(str(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
@@ -652,50 +663,61 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except Exception as e:
raise ToolException(str(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
await run_manager.on_tool_error(error_to_raise)

View File

@@ -2,9 +2,9 @@ import inspect
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError, create_model
from langchain_core.runnables import Runnable
from langchain_core.tools.base import BaseTool
from langchain_core.tools.base import BaseTool, ToolException
from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool
@@ -17,6 +17,12 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
] = False,
handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]]
] = False,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -41,6 +47,15 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
handle_tool_error: Handle the content of the ToolException thrown. If False
do nothing, If True returns 'Tool execution error'. If string then
returns that string directly, if callable converts error to string to be returned.
Defaults to False.
handle_validation_error: Handle the content of the ValidationError thrown. If False
do nothing, If True returns 'Tool input validation error'. If string then
returns that string directly, if callable converts error to string to be returned.
Defaults to False.
Returns:
The tool.
@@ -137,7 +152,7 @@ def tool(
monkey: The baz.
\"\"\"
return bar
"""
""" # noqa: E501
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
@@ -184,6 +199,8 @@ def tool(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function

View File

@@ -679,7 +679,7 @@ def test_exception_handling_callable() -> None:
def test_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
with pytest.raises(ToolException):
_tool.run({})
@@ -710,7 +710,7 @@ async def test_async_exception_handling_callable() -> None:
async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
with pytest.raises(ToolException):
await _tool.arun({})
@@ -806,7 +806,7 @@ def test_validation_error_handling_non_validation_error(
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
_tool.run({})
@@ -868,7 +868,7 @@ async def test_async_validation_error_handling_non_validation_error(
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
await _tool.arun({})
@@ -1408,10 +1408,8 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
with pytest.raises(expected_error):
with pytest.raises(ToolException):
tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == {
@@ -1873,7 +1871,6 @@ def test__get_all_basemodel_annotations_v1() -> None:
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic import Field as FieldV2 # pydantic: ignore
from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore
class Foo(BaseModelV2):
x: List[int] = FieldV2(
@@ -1903,7 +1900,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
}
assert foo.invoke({"x": [0] * 10})
with pytest.raises(ValidationErrorV2):
with pytest.raises(ToolException):
foo.invoke({"x": [0] * 9})