langchain/libs/community/langchain_community/tools/shell/tool.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

104 lines
3.1 KiB
Python

import logging
import platform
import warnings
from typing import Any, List, Optional, Type, Union
from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator
logger = logging.getLogger(__name__)
class ShellInput(BaseModel):
    """Input schema for the shell tool: the command(s) to execute."""

    commands: Union[str, List[str]] = Field(
        ...,
        description="List of shell commands to run. Deserialized using json.loads",
    )
    """List of shell commands to run."""

    @model_validator(mode="before")
    @classmethod
    def _validate_commands(cls, values: Any) -> Any:
        """Normalize ``commands`` to a list and warn that the tool is unsafe.

        Runs before field validation, so ``values`` is the raw input and is
        not guaranteed to be a dict.

        Args:
            values: Raw input handed to the model.

        Returns:
            The input unchanged when it is not a dict or when ``commands`` is
            missing, ``None`` or already a list; otherwise a shallow copy
            with the single ``commands`` value wrapped in a one-element list.
        """
        # TODO: Add real validators
        # Warn that the bash tool is not safe
        warnings.warn(
            "The shell tool has no safeguards by default. Use at your own risk."
        )
        if not isinstance(values, dict):
            # Nothing to normalize; let regular validation deal with it
            # instead of failing on ``.get`` with an AttributeError.
            return values
        commands = values.get("commands")
        if commands is None or isinstance(commands, list):
            # A missing/None value is left alone so field validation reports
            # the real problem rather than an error about ``[None]``.
            return values
        # Build a new dict so the caller's input mapping is not mutated.
        return {**values, "commands": [commands]}
def _get_default_bash_process() -> Any:
    """Create the default process object used to run shell commands.

    The import is done lazily so that ``langchain-experimental`` is only
    required when this factory is actually called.

    Returns:
        A ``BashProcess`` constructed with ``return_err_output=True``
        (presumably so error output is returned to the caller -- confirm
        against ``BashProcess``).

    Raises:
        ImportError: If ``langchain_experimental`` cannot be imported. The
            original import failure is attached as the cause.
    """
    try:
        from langchain_experimental.llm_bash.bash import BashProcess
    except ImportError as exc:
        # Adjacent literals are concatenated verbatim, so each sentence must
        # carry its own trailing space.
        raise ImportError(
            "BashProcess has been moved to langchain experimental. "
            "To use this tool, install langchain-experimental "
            "with `pip install langchain-experimental`."
        ) from exc
    return BashProcess(return_err_output=True)
def _get_platform() -> str:
    """Return the name from ``platform.system()``, with "Darwin" shown as "MacOS"."""
    os_name = platform.system()
    # Every other system name is passed through untouched.
    return "MacOS" if os_name == "Darwin" else os_name
class ShellTool(BaseTool):
    """Tool that passes shell commands to a process object and returns its output."""

    process: Any = Field(default_factory=_get_default_bash_process)
    """Process object whose ``run`` method executes the commands."""

    name: str = "terminal"
    """Tool name."""

    description: str = f"Run shell commands on this {_get_platform()} machine."
    """Tool description; the platform name is filled in when the class is defined."""

    args_schema: Type[BaseModel] = ShellInput
    """Model describing the tool's input arguments."""

    ask_human_input: bool = False
    """
    When True, ask on stdin for a y/n confirmation before running the
    commands; any answer other than "y" (case-insensitive) aborts the run.
    """

    def _run(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute ``commands`` and return the process output.

        Args:
            commands: A single command or a list of commands.
            run_manager: Optional callback manager; not used here.

        Returns:
            Whatever ``self.process.run(commands)`` returns, or ``None`` when
            the user declines the confirmation prompt or any exception is
            raised along the way (the exception is logged, not propagated).
        """
        print(f"Executing command:\n {commands}")  # noqa: T201
        try:
            if self.ask_human_input:
                answer = input("Proceed with command execution? (y/n): ").lower()
                if answer != "y":
                    # Anything but an explicit "y" counts as a refusal.
                    logger.info("Invalid input. User aborted command execution.")
                    return None  # type: ignore[return-value]
            # Reached when no confirmation is required or the user agreed.
            return self.process.run(commands)
        except Exception as err:
            logger.error(f"Error during command execution: {err}")
            return None  # type: ignore[return-value]