Compare commits

...

3 Commits

Author SHA1 Message Date
isaac hershenson
93580fb5ec repr instead of str 2024-09-18 16:37:05 -07:00
isaac hershenson
f9cd109280 delete _get_filtered_args 2024-09-18 15:06:06 -07:00
isaac hershenson
765350b630 error handling in tool decorator 2024-09-16 08:44:55 -07:00
3 changed files with 150 additions and 96 deletions

View File

@@ -90,26 +90,6 @@ def _get_annotation_description(arg_type: Type) -> str | None:
return annotation
return None
def _get_filtered_args(
    inferred_model: Type[BaseModel],
    func: Callable,
    *,
    filter_args: Sequence[str],
    include_injected: bool = True,
) -> dict:
    """Get the arguments from a function's signature.

    Builds a mapping from parameter name to its JSON-schema property taken from
    ``inferred_model``, skipping names listed in ``filter_args``, a leading
    ``self``/``cls`` receiver, and (when ``include_injected`` is False) any
    parameter whose annotation is an injected-argument type.
    """
    properties = inferred_model.model_json_schema()["properties"]
    filtered: dict = {}
    for position, (name, param) in enumerate(signature(func).parameters.items()):
        if name in filter_args:
            continue
        # Only the very first parameter may be an implicit self/cls receiver.
        if position == 0 and param.name in ("self", "cls"):
            continue
        if not include_injected and _is_injected_arg_type(param.annotation):
            continue
        filtered[name] = properties[name]
    return filtered
def _parse_python_function_docstring(
function: Callable, annotations: dict, error_on_invalid_docstring: bool = False
) -> Tuple[str, dict]:
@@ -649,43 +629,50 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
if signature(self._run).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
if config_param := _get_runnable_config_param(self._run):
tool_kwargs[config_param] = config
response = context.run(self._run, *tool_args, **tool_kwargs)
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
status = "error"
except ToolException as e:
raise e
except Exception as e:
raise ToolException(repr(e)) from e
except KeyboardInterrupt as e:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
status = "error"
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
run_manager.on_tool_error(error_to_raise)
@@ -758,50 +745,54 @@ class ChildTool(BaseTool):
artifact = None
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
try:
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
child_config = patch_config(config, callbacks=run_manager.get_child())
context = copy_context()
context.run(_set_config_context, child_config)
func_to_check = (
self._run if self.__class__._arun is BaseTool._arun else self._arun
)
if signature(func_to_check).parameters.get("run_manager"):
tool_kwargs["run_manager"] = run_manager
if config_param := _get_runnable_config_param(func_to_check):
tool_kwargs[config_param] = config
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
coro = context.run(self._arun, *tool_args, **tool_kwargs)
if asyncio_accepts_context():
response = await asyncio.create_task(coro, context=context) # type: ignore
else:
response = await coro
if self.response_format == "content_and_artifact":
if not isinstance(response, tuple) or len(response) != 2:
raise ValueError(
"Since response_format='content_and_artifact' "
"a two-tuple of the message content and raw tool output is "
f"expected. Instead generated response of type: "
f"{type(response)}."
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(
e, flag=self.handle_validation_error
)
content, artifact = response
else:
content = response
status = "success"
except ValidationError as e:
if not self.handle_validation_error:
error_to_raise = e
else:
content = _handle_validation_error(e, flag=self.handle_validation_error)
status = "error"
status = "error"
except ToolException as e:
raise e
except Exception as e:
raise ToolException(repr(e)) from e
except ToolException as e:
if not self.handle_tool_error:
error_to_raise = e
else:
content = _handle_tool_error(e, flag=self.handle_tool_error)
status = "error"
except (Exception, KeyboardInterrupt) as e:
error_to_raise = e
status = "error"
if error_to_raise:
await run_manager.on_tool_error(error_to_raise)

View File

@@ -1,11 +1,11 @@
import inspect
from typing import Any, Callable, Dict, Literal, Optional, Type, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel, Field, ValidationError, create_model
from langchain_core.callbacks import Callbacks
from langchain_core.runnables import Runnable
from langchain_core.tools.base import BaseTool
from langchain_core.tools.base import BaseTool, ToolException
from langchain_core.tools.simple import Tool
from langchain_core.tools.structured import StructuredTool
@@ -18,6 +18,12 @@ def tool(
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
error_on_invalid_docstring: bool = True,
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
] = False,
handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]]
] = False,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -42,6 +48,14 @@ def tool(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to True.
handle_tool_error: Handle the content of the ToolException thrown. If False
do nothing, If True returns 'Tool execution error'. If string then
returns that string directly, if callable converts error to string to be returned.
Defaults to False.
handle_validation_error: Handle the content of the ValidationError thrown. If False
do nothing, If True returns 'Tool input validation error'. If string then
returns that string directly, if callable converts error to string to be returned.
Defaults to False.
Returns:
The tool.
@@ -138,7 +152,7 @@ def tool(
monkey: The baz.
\"\"\"
return bar
"""
""" # noqa: E501
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(dec_func: Union[Callable, Runnable]) -> BaseTool:
@@ -185,6 +199,8 @@ def tool(
response_format=response_format,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
@@ -200,6 +216,8 @@ def tool(
return_direct=return_direct,
coroutine=coroutine,
response_format=response_format,
handle_tool_error=handle_tool_error,
handle_validation_error=handle_validation_error,
)
return _make_tool

View File

@@ -168,6 +168,54 @@ def test_subclass_annotated_base_tool_accepted() -> None:
assert tool.args_schema == _MockSchema
def test_decorator_with_error_handling() -> None:
    """Test that error handling works when passed through decorator."""

    @tool()
    def tool_func(arg1: int, arg2: int) -> float:
        """foo bar tool"""
        return arg1 / arg2

    # Without handlers, both failure modes surface as exceptions.
    with pytest.raises(ToolException):
        tool_func.invoke({"arg1": 1, "arg2": 0})
    with pytest.raises(ValidationError):
        tool_func.invoke({"arg1": 1})

    @tool(handle_tool_error="foo", handle_validation_error="bar")
    def tool_func_2(arg1: int, arg2: int) -> float:
        """foo bar tool"""
        return arg1 / arg2

    # String handlers: the configured string is returned verbatim.
    assert tool_func_2.invoke({"arg1": 1, "arg2": 0}) == "foo"
    assert tool_func_2.invoke({"arg1": 1}) == "bar"

    @tool(handle_tool_error=lambda e: "foo", handle_validation_error=lambda e: "bar")
    def tool_func_3(arg1: int, arg2: int) -> float:
        """foo bar tool"""
        return arg1 / arg2

    # Callable handlers: whatever the callable returns is used as content.
    assert tool_func_3.invoke({"arg1": 1, "arg2": 0}) == "foo"
    assert tool_func_3.invoke({"arg1": 1}) == "bar"

    @tool(handle_tool_error=True, handle_validation_error=True)
    def tool_func_4(arg1: int, arg2: int) -> float:
        """foo bar tool"""
        return arg1 / arg2

    # Boolean handlers: the library's default messages are returned.
    assert tool_func_4.invoke({"arg1": 1, "arg2": 0}) == "division by zero"
    assert tool_func_4.invoke({"arg1": 1}) == "Tool input validation error"
def test_decorator_with_specified_schema() -> None:
"""Test that manually specified schemata are passed through to the tool."""
@@ -740,7 +788,7 @@ def test_exception_handling_callable() -> None:
def test_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
with pytest.raises(ToolException):
_tool.run({})
@@ -771,7 +819,7 @@ async def test_async_exception_handling_callable() -> None:
async def test_async_exception_handling_non_tool_exception() -> None:
_tool = _FakeExceptionTool(exception=ValueError())
with pytest.raises(ValueError):
with pytest.raises(ToolException):
await _tool.arun({})
@@ -867,7 +915,7 @@ def test_validation_error_handling_non_validation_error(
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
_tool.run({})
@@ -929,7 +977,7 @@ async def test_async_validation_error_handling_non_validation_error(
return "dummy"
_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) # type: ignore[call-arg]
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
await _tool.arun({})
@@ -1470,10 +1518,7 @@ def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
with pytest.raises(expected_error):
with pytest.raises(ToolException):
tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == {
@@ -2092,5 +2137,5 @@ def test_structured_tool_direct_init() -> None:
tool = StructuredTool(name="foo", args_schema=fooSchema, coroutine=asyncFoo)
with pytest.raises(NotImplementedError):
with pytest.raises(ToolException):
assert tool.invoke("hello") == "hello"