Compare commits

...

3 Commits

Author SHA1 Message Date
isaac hershenson
5e401dad81 bagatur comment 2024-09-11 18:46:47 -07:00
Isaac Francisco
541d04000e Merge branch 'v0.3rc' into isaac/toolerrorhandling03 2024-09-11 18:42:43 -07:00
isaac hershenson
a67435a047 tool errors during execution as ToolExeptions 2024-09-11 12:13:54 -07:00
3 changed files with 105 additions and 75 deletions

View File

@@ -649,43 +649,50 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
status = "error"
except ToolException as e:
raise e
except Exception as e:
raise ToolException(str(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
@@ -758,50 +765,57 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
status = "error"
except ToolException as e:
raise e
except Exception as e:
raise ToolException(str(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
await run_manager.on_tool_error(error_to_raise)

View File

@@ -1,11 +1,11 @@
import inspect
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, Field, ValidationError, create_model
from langchain_core.callbacks import Callbacks
from langchain_core.runnables import Runnable
from langchain_core.tools.base import BaseTool
from langchain_core.tools.base import BaseTool, ToolException
from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool
@@ -18,6 +18,12 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
] = False,
handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]]
] = False,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -42,6 +48,14 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
handle_tool_error: Handle the content of the ToolException thrown. If False
do nothing, If True returns 'Tool execution error'. If string then
returns that string directly, if callable converts error to string to be returned.
Defaults to False.
handle_validation_error: Handle the content of the ValidationError thrown. If False
do nothing, If True returns 'Tool input validation error'. If string then
returns that string directly, if callable converts error to string to be returned.
Defaults to False.
Returns:
The tool.
@@ -138,7 +152,7 @@ def tool(
monkey: The baz.
\"\"\"
return bar
"""
""" # noqa: E501
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
@@ -185,6 +199,8 @@ def tool(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
@@ -200,6 +216,8 @@ def tool(
return_direct=return_direct,
coroutine=coroutine,
response_format=response_format,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
return _make_tool

View File

@@ -739,7 +739,7 @@ def test_exception_handling_callable() -> None:
def test_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
with pytest.raises(ToolException):
_tool.run({})
@@ -770,7 +770,7 @@ async def test_async_exception_handling_callable() -> None:
async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
with pytest.raises(ToolException):
await _tool.arun({})
@@ -866,7 +866,7 @@ def test_validation_error_handling_non_validation_error(
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
_tool.run({})
@@ -928,7 +928,7 @@ async def test_async_validation_error_handling_non_validation_error(
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
await _tool.arun({})
@@ -1469,10 +1469,8 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
with pytest.raises(expected_error):
with pytest.raises(ToolException):
tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == {
@@ -2051,5 +2049,5 @@ def test_structured_tool_direct_init() -> None:
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
assert tool.invoke("hello") == "hello"