diff --git a/libs/community/langchain_community/tools/shell/tool.py b/libs/community/langchain_community/tools/shell/tool.py index 5f61631059d..e26deb365dc 100644 --- a/libs/community/langchain_community/tools/shell/tool.py +++ b/libs/community/langchain_community/tools/shell/tool.py @@ -1,3 +1,4 @@ +import logging import platform import warnings from typing import Any, List, Optional, Type, Union @@ -8,6 +9,8 @@ from langchain_core.callbacks import ( from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.tools import BaseTool +logger = logging.getLogger(__name__) + class ShellInput(BaseModel): """Commands for the Bash Shell tool.""" @@ -68,10 +71,32 @@ class ShellTool(BaseTool): args_schema: Type[BaseModel] = ShellInput """Schema for input arguments.""" + ask_human_input: bool = False + """ + If True, prompts the user for confirmation (y/n) before executing + a command generated by the language model in the bash shell. + """ + def _run( self, commands: Union[str, List[str]], run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Run commands and return final output.""" - return self.process.run(commands) + + print(f"Executing command:\n {commands}") + + try: + if self.ask_human_input: + user_input = input("Proceed with command execution? (y/n): ").lower() + if user_input == "y": + return self.process.run(commands) + else: + logger.info("Invalid input. User aborted command execution.") + return None + else: + return self.process.run(commands) + + except Exception as e: + logger.error(f"Error during command execution: {e}") + return None diff --git a/libs/community/tests/unit_tests/tools/shell/test_shell.py b/libs/community/tests/unit_tests/tools/shell/test_shell.py index ab6b5abe38c..b792505f1c4 100644 --- a/libs/community/tests/unit_tests/tools/shell/test_shell.py +++ b/libs/community/tests/unit_tests/tools/shell/test_shell.py @@ -1,5 +1,6 @@ import warnings from typing import List +from unittest.mock import patch from langchain_community.tools.shell.tool import ShellInput, ShellTool @@ -65,3 +66,29 @@ def test_shell_tool_run_str() -> None: shell_tool = ShellTool(process=placeholder) result = shell_tool._run(commands="echo 'Hello, World!'") assert result.strip() == "hello" + + +async def test_shell_tool_arun_with_user_confirmation() -> None: + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder, ask_human_input=True) + + with patch("builtins.input", return_value="y"): + result = await shell_tool._arun(commands=test_commands) + assert result.strip() == "hello" + + with patch("builtins.input", return_value="n"): + result = await shell_tool._arun(commands=test_commands) + assert result is None + + +def test_shell_tool_run_with_user_confirmation() -> None: + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder, ask_human_input=True) + + with patch("builtins.input", return_value="y"): + result = shell_tool._run(commands="echo 'Hello, World!'") + assert result.strip() == "hello" + + with patch("builtins.input", return_value="n"): + result = shell_tool._run(commands="echo 'Hello, World!'") + assert result is None