This commit is contained in:
Harrison Chase
2023-04-16 09:10:56 -07:00
parent 21a1ac36b5
commit db0a9c14cf
3 changed files with 14 additions and 14 deletions

View File

@@ -38,13 +38,13 @@ class InvalidTool(BaseTool):
name = "invalid_tool"
description = "Called when tool name is invalid."
def _run(self, tool_name: str) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
return f"{tool_name} is not a valid tool, try another one."
return f"{tool_input} is not a valid tool, try another one."
async def _arun(self, tool_name: str) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
return f"{tool_name} is not a valid tool, try another one."
return f"{tool_input} is not a valid tool, try another one."
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:

View File

@@ -63,7 +63,7 @@ class BaseTool(ABC, BaseModel):
) -> str:
"""Run the tool."""
if isinstance(tool_input, str):
if len(self.args) > 0:
if len(self.args) > 1:
raise ValueError("Cannot call run on tools with > 1 argument")
key = self.args[0].name
run_input = {key: tool_input}
@@ -100,7 +100,7 @@ class BaseTool(ABC, BaseModel):
) -> str:
"""Run the tool asynchronously."""
if isinstance(tool_input, str):
if len(self.args) > 0:
if len(self.args) > 1:
raise ValueError("Cannot call run on tools with > 1 argument")
key = self.args[0].name
run_input = {key: tool_input}

View File

@@ -28,13 +28,13 @@ class PythonREPLTool(BaseTool):
python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
sanitize_input: bool = True
def _run(self, query: str) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
if self.sanitize_input:
query = query.strip().strip("```")
return self.python_repl.run(query)
tool_input = tool_input.strip().strip("```")
return self.python_repl.run(tool_input)
async def _arun(self, query: str) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("PythonReplTool does not support async")
@@ -64,13 +64,13 @@ class PythonAstREPLTool(BaseTool):
)
return values
def _run(self, query: str) -> str:
def _run(self, tool_input: str) -> str:
"""Use the tool."""
try:
if self.sanitize_input:
# Remove the triple backticks from the query.
query = query.strip().strip("```")
tree = ast.parse(query)
tool_input = tool_input.strip().strip("```")
tree = ast.parse(tool_input)
module = ast.Module(tree.body[:-1], type_ignores=[])
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
module_end = ast.Module(tree.body[-1:], type_ignores=[])
@@ -91,6 +91,6 @@ class PythonAstREPLTool(BaseTool):
except Exception as e:
return str(e)
async def _arun(self, query: str) -> str:
async def _arun(self, tool_input: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("PythonReplTool does not support async")