Compare commits

...

3 Commits

Author SHA1 Message Date
vowelparrot
e73d588ed2 mvoe out 2023-04-22 18:28:20 -07:00
vowelparrot
5da420add9 Merge branch 'master' into vwp/default_dont_raise 2023-04-22 18:25:23 -07:00
vowelparrot
b0b8e1a6ab Default to Not Raise Errors
For tools, this is an option to not raise errors

Cons: adds yet another argument to the tool..
2023-04-20 22:55:42 -07:00
4 changed files with 20 additions and 5 deletions

View File

@@ -66,6 +66,7 @@ def tool(
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
raise_errors: bool = False,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.
@@ -77,6 +78,8 @@ def tool(
infer_schema: Whether to infer the schema of the arguments from
the function's signature. This also makes the resultant tool
accept a dictionary input to its `run()` function.
raise_errors: Whether to raise exceptions when running the tool
rather than returning a string with the error message.
Requires:
- Function must be of type (str) -> str
@@ -111,6 +114,7 @@ def tool(
args_schema=_args_schema,
description=description,
return_direct=return_direct,
raise_errors=raise_errors,
)
return tool_

View File

@@ -29,6 +29,7 @@ class BaseTool(ABC, BaseModel):
return_direct: bool = False
verbose: bool = False
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
raise_errors: bool = False
class Config:
"""Configuration for this pydantic object."""
@@ -78,6 +79,12 @@ class BaseTool(ABC, BaseModel):
async def _arun(self, *args: Any, **kwargs: Any) -> str:
"""Use the tool asynchronously."""
def handle_error(self, error: Exception) -> str:
"""Handle an error raised by the tool."""
if self.raise_errors:
raise error
return f"Error: {error}, {type(error)}"
def run(
self,
tool_input: Union[str, Dict],
@@ -103,8 +110,10 @@ class BaseTool(ABC, BaseModel):
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
self.callback_manager.on_tool_error(e, verbose=verbose)
if isinstance(e, KeyboardInterrupt):
raise e
return self.handle_error(e)
self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs
)
@@ -149,7 +158,9 @@ class BaseTool(ABC, BaseModel):
await self.callback_manager.on_tool_error(e, verbose=verbose_)
else:
self.callback_manager.on_tool_error(e, verbose=verbose_)
raise e
if isinstance(e, KeyboardInterrupt):
raise e
return self.handle_error(e)
if self.callback_manager.is_async:
await self.callback_manager.on_tool_end(
observation, verbose=verbose_, color=color, name=self.name, **kwargs

View File

@@ -123,7 +123,7 @@ def test_unnamed_tool_decorator_return_direct() -> None:
def test_tool_with_kwargs() -> None:
"""Test functionality when only return direct is provided."""
@tool(return_direct=True)
@tool(return_direct=True, raise_errors=True)
def search_api(
arg_1: float,
ping: str = "hi",

View File

@@ -20,7 +20,7 @@ def test_write_file_with_root_dir() -> None:
def test_write_file_errs_outside_root_dir() -> None:
"""Test the WriteFile tool when a root dir is specified."""
with TemporaryDirectory() as temp_dir:
tool = WriteFileTool(root_dir=temp_dir)
tool = WriteFileTool(root_dir=temp_dir, raise_errors=True)
with pytest.raises(ValueError):
tool.run({"file_path": "../file.txt", "text": "Hello, world!"})