mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
cr
This commit is contained in:
@@ -38,13 +38,13 @@ class InvalidTool(BaseTool):
|
||||
name = "invalid_tool"
|
||||
description = "Called when tool name is invalid."
|
||||
|
||||
def _run(self, tool_name: str) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
return f"{tool_name} is not a valid tool, try another one."
|
||||
return f"{tool_input} is not a valid tool, try another one."
|
||||
|
||||
async def _arun(self, tool_name: str) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
return f"{tool_name} is not a valid tool, try another one."
|
||||
return f"{tool_input} is not a valid tool, try another one."
|
||||
|
||||
|
||||
def tool(*args: Union[str, Callable], return_direct: bool = False) -> Callable:
|
||||
|
||||
@@ -63,7 +63,7 @@ class BaseTool(ABC, BaseModel):
|
||||
) -> str:
|
||||
"""Run the tool."""
|
||||
if isinstance(tool_input, str):
|
||||
if len(self.args) > 0:
|
||||
if len(self.args) > 1:
|
||||
raise ValueError("Cannot call run on tools with > 1 argument")
|
||||
key = self.args[0].name
|
||||
run_input = {key: tool_input}
|
||||
@@ -100,7 +100,7 @@ class BaseTool(ABC, BaseModel):
|
||||
) -> str:
|
||||
"""Run the tool asynchronously."""
|
||||
if isinstance(tool_input, str):
|
||||
if len(self.args) > 0:
|
||||
if len(self.args) > 1:
|
||||
raise ValueError("Cannot call run on tools with > 1 argument")
|
||||
key = self.args[0].name
|
||||
run_input = {key: tool_input}
|
||||
|
||||
@@ -28,13 +28,13 @@ class PythonREPLTool(BaseTool):
|
||||
python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
|
||||
sanitize_input: bool = True
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
if self.sanitize_input:
|
||||
query = query.strip().strip("```")
|
||||
return self.python_repl.run(query)
|
||||
tool_input = tool_input.strip().strip("```")
|
||||
return self.python_repl.run(tool_input)
|
||||
|
||||
async def _arun(self, query: str) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("PythonReplTool does not support async")
|
||||
|
||||
@@ -64,13 +64,13 @@ class PythonAstREPLTool(BaseTool):
|
||||
)
|
||||
return values
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
try:
|
||||
if self.sanitize_input:
|
||||
# Remove the triple backticks from the query.
|
||||
query = query.strip().strip("```")
|
||||
tree = ast.parse(query)
|
||||
tool_input = tool_input.strip().strip("```")
|
||||
tree = ast.parse(tool_input)
|
||||
module = ast.Module(tree.body[:-1], type_ignores=[])
|
||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||
module_end = ast.Module(tree.body[-1:], type_ignores=[])
|
||||
@@ -91,6 +91,6 @@ class PythonAstREPLTool(BaseTool):
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
async def _arun(self, query: str) -> str:
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("PythonReplTool does not support async")
|
||||
|
||||
Reference in New Issue
Block a user