From f238217cea34a6fbb503c3baca9848b8862e0afc Mon Sep 17 00:00:00 2001 From: Krishna Shedbalkar <60742358+krishnashed@users.noreply.github.com> Date: Thu, 18 Jan 2024 02:27:51 +0530 Subject: [PATCH] community[patch]: Basic Logging and Human input to ShellTool (#15932) - **Description:** As Shell tool is very versatile, while integrating it into applications as openai functions, developers have no clue about what command is being executed using the ShellTool. All one can see is: ![image](https://github.com/langchain-ai/langchain/assets/60742358/540e274a-debc-4564-9027-046b91424df3) Summarising my feature request: 1. There's no visibility about what command was executed. 2. There's no mechanism to prevent a command to be executed using ShellTool, like a y/n human input which can be accepted from user to proceed with executing the command., - **Issue:** the issue #15931 it fixes if applicable, - **Dependencies:** There isn't any dependancy, - **Twitter handle:** @krishnashed --- .../langchain_community/tools/shell/tool.py | 27 ++++++++++++++++++- .../unit_tests/tools/shell/test_shell.py | 27 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/tools/shell/tool.py b/libs/community/langchain_community/tools/shell/tool.py index 5f61631059d..e26deb365dc 100644 --- a/libs/community/langchain_community/tools/shell/tool.py +++ b/libs/community/langchain_community/tools/shell/tool.py @@ -1,3 +1,4 @@ +import logging import platform import warnings from typing import Any, List, Optional, Type, Union @@ -8,6 +9,8 @@ from langchain_core.callbacks import ( from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.tools import BaseTool +logger = logging.getLogger(__name__) + class ShellInput(BaseModel): """Commands for the Bash Shell tool.""" @@ -68,10 +71,32 @@ class ShellTool(BaseTool): args_schema: Type[BaseModel] = ShellInput """Schema for input arguments.""" + ask_human_input: bool = False + """ + If True, prompts the user for confirmation (y/n) before executing + a command generated by the language model in the bash shell. + """ + def _run( self, commands: Union[str, List[str]], run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Run commands and return final output.""" - return self.process.run(commands) + + print(f"Executing command:\n {commands}") + + try: + if self.ask_human_input: + user_input = input("Proceed with command execution? (y/n): ").lower() + if user_input == "y": + return self.process.run(commands) + else: + logger.info("Invalid input. User aborted command execution.") + return None + else: + return self.process.run(commands) + + except Exception as e: + logger.error(f"Error during command execution: {e}") + return None diff --git a/libs/community/tests/unit_tests/tools/shell/test_shell.py b/libs/community/tests/unit_tests/tools/shell/test_shell.py index ab6b5abe38c..b792505f1c4 100644 --- a/libs/community/tests/unit_tests/tools/shell/test_shell.py +++ b/libs/community/tests/unit_tests/tools/shell/test_shell.py @@ -1,5 +1,6 @@ import warnings from typing import List +from unittest.mock import patch from langchain_community.tools.shell.tool import ShellInput, ShellTool @@ -65,3 +66,29 @@ def test_shell_tool_run_str() -> None: shell_tool = ShellTool(process=placeholder) result = shell_tool._run(commands="echo 'Hello, World!'") assert result.strip() == "hello" + + +async def test_shell_tool_arun_with_user_confirmation() -> None: + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder, ask_human_input=True) + + with patch("builtins.input", return_value="y"): + result = await shell_tool._arun(commands=test_commands) + assert result.strip() == "hello" + + with patch("builtins.input", return_value="n"): + result = await shell_tool._arun(commands=test_commands) + assert result is None + + +def test_shell_tool_run_with_user_confirmation() -> None: + placeholder = PlaceholderProcess(output="hello") + shell_tool = ShellTool(process=placeholder, ask_human_input=True) + + with patch("builtins.input", return_value="y"): + result = shell_tool._run(commands="echo 'Hello, World!'") + assert result.strip() == "hello" + + with patch("builtins.input", return_value="n"): + result = shell_tool._run(commands="echo 'Hello, World!'") + assert result is None