diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 99e9b741f9f..28593e96766 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -38,13 +38,13 @@ class InvalidTool(BaseTool): name = "invalid_tool" description = "Called when tool name is invalid." - def _run(self, tool_name: str) -> str: + def _run(self, tool_input: str) -> str: """Use the tool.""" - return f"{tool_name} is not a valid tool, try another one." + return f"{tool_input} is not a valid tool, try another one." - async def _arun(self, tool_name: str) -> str: + async def _arun(self, tool_input: str) -> str: """Use the tool asynchronously.""" - return f"{tool_name} is not a valid tool, try another one." + return f"{tool_input} is not a valid tool, try another one." def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable: diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 4fbb477e806..8f82fbffc2a 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -63,7 +63,7 @@ class BaseTool(ABC, BaseModel): ) -> str: """Run the tool.""" if isinstance(tool_input, str): - if len(self.args) > 0: + if len(self.args) > 1: raise ValueError("Cannot call run on tools with > 1 argument") key = self.args[0].name run_input = {key: tool_input} @@ -100,7 +100,7 @@ class BaseTool(ABC, BaseModel): ) -> str: """Run the tool asynchronously.""" if isinstance(tool_input, str): - if len(self.args) > 0: + if len(self.args) > 1: raise ValueError("Cannot call run on tools with > 1 argument") key = self.args[0].name run_input = {key: tool_input} diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 81339cedc34..70aaefe2adb 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -28,13 +28,13 @@ class PythonREPLTool(BaseTool): python_repl: PythonREPL = Field(default_factory=_get_default_python_repl) sanitize_input: bool = True - def _run(self, query: str) -> str: + def _run(self, tool_input: str) -> str: """Use the tool.""" if self.sanitize_input: - query = query.strip().strip("```") - return self.python_repl.run(query) + tool_input = tool_input.strip().strip("```") + return self.python_repl.run(tool_input) - async def _arun(self, query: str) -> str: + async def _arun(self, tool_input: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("PythonReplTool does not support async") @@ -64,13 +64,13 @@ class PythonAstREPLTool(BaseTool): ) return values - def _run(self, query: str) -> str: + def _run(self, tool_input: str) -> str: """Use the tool.""" try: if self.sanitize_input: # Remove the triple backticks from the query. - query = query.strip().strip("```") - tree = ast.parse(query) + tool_input = tool_input.strip().strip("```") + tree = ast.parse(tool_input) module = ast.Module(tree.body[:-1], type_ignores=[]) exec(ast.unparse(module), self.globals, self.locals) # type: ignore module_end = ast.Module(tree.body[-1:], type_ignores=[]) @@ -91,6 +91,6 @@ class PythonAstREPLTool(BaseTool): except Exception as e: return str(e) - async def _arun(self, query: str) -> str: + async def _arun(self, tool_input: str) -> str: """Use the tool asynchronously.""" raise NotImplementedError("PythonReplTool does not support async")