diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 000000000..26d33521a --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/DB-GPT.iml b/.idea/DB-GPT.iml new file mode 100644 index 000000000..9725c1b01 --- /dev/null +++ b/.idea/DB-GPT.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 000000000..105ce2da2 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 000000000..e965926fe --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 000000000..a22f312c8 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 000000000..94a25f7f4 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/pilot/agent/base_open_ai_plugin.py b/pilot/agent/base_open_ai_plugin.py new file mode 100644 index 000000000..046295c0d --- /dev/null +++ b/pilot/agent/base_open_ai_plugin.py @@ -0,0 +1,199 @@ +"""Handles loading of plugins.""" +from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar + +from auto_gpt_plugin_template import AutoGPTPluginTemplate + +PromptGenerator = TypeVar("PromptGenerator") + + +class Message(TypedDict): + role: str + content: str + + +class BaseOpenAIPlugin(AutoGPTPluginTemplate): + """ + This is a BaseOpenAIPlugin class for generating Auto-GPT plugins. + """ + + def __init__(self, manifests_specs_clients: dict): + # super().__init__() + self._name = manifests_specs_clients["manifest"]["name_for_model"] + self._version = manifests_specs_clients["manifest"]["schema_version"] + self._description = manifests_specs_clients["manifest"]["description_for_model"] + self._client = manifests_specs_clients["client"] + self._manifest = manifests_specs_clients["manifest"] + self._openapi_spec = manifests_specs_clients["openapi_spec"] + + def can_handle_on_response(self) -> bool: + """This method is called to check that the plugin can + handle the on_response method. + Returns: + bool: True if the plugin can handle the on_response method.""" + return False + + def on_response(self, response: str, *args, **kwargs) -> str: + """This method is called when a response is received from the model.""" + return response + + def can_handle_post_prompt(self) -> bool: + """This method is called to check that the plugin can + handle the post_prompt method. + Returns: + bool: True if the plugin can handle the post_prompt method.""" + return False + + def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator: + """This method is called just after the generate_prompt is called, + but actually before the prompt is generated. + Args: + prompt (PromptGenerator): The prompt generator. + Returns: + PromptGenerator: The prompt generator. + """ + return prompt + + def can_handle_on_planning(self) -> bool: + """This method is called to check that the plugin can + handle the on_planning method. + Returns: + bool: True if the plugin can handle the on_planning method.""" + return False + + def on_planning( + self, prompt: PromptGenerator, messages: List[Message] + ) -> Optional[str]: + """This method is called before the planning chat completion is done. + Args: + prompt (PromptGenerator): The prompt generator. + messages (List[str]): The list of messages. + """ + pass + + def can_handle_post_planning(self) -> bool: + """This method is called to check that the plugin can + handle the post_planning method. + Returns: + bool: True if the plugin can handle the post_planning method.""" + return False + + def post_planning(self, response: str) -> str: + """This method is called after the planning chat completion is done. + Args: + response (str): The response. + Returns: + str: The resulting response. + """ + return response + + def can_handle_pre_instruction(self) -> bool: + """This method is called to check that the plugin can + handle the pre_instruction method. + Returns: + bool: True if the plugin can handle the pre_instruction method.""" + return False + + def pre_instruction(self, messages: List[Message]) -> List[Message]: + """This method is called before the instruction chat is done. + Args: + messages (List[Message]): The list of context messages. + Returns: + List[Message]: The resulting list of messages. + """ + return messages + + def can_handle_on_instruction(self) -> bool: + """This method is called to check that the plugin can + handle the on_instruction method. + Returns: + bool: True if the plugin can handle the on_instruction method.""" + return False + + def on_instruction(self, messages: List[Message]) -> Optional[str]: + """This method is called when the instruction chat is done. + Args: + messages (List[Message]): The list of context messages. + Returns: + Optional[str]: The resulting message. + """ + pass + + def can_handle_post_instruction(self) -> bool: + """This method is called to check that the plugin can + handle the post_instruction method. + Returns: + bool: True if the plugin can handle the post_instruction method.""" + return False + + def post_instruction(self, response: str) -> str: + """This method is called after the instruction chat is done. + Args: + response (str): The response. + Returns: + str: The resulting response. + """ + return response + + def can_handle_pre_command(self) -> bool: + """This method is called to check that the plugin can + handle the pre_command method. + Returns: + bool: True if the plugin can handle the pre_command method.""" + return False + + def pre_command( + self, command_name: str, arguments: Dict[str, Any] + ) -> Tuple[str, Dict[str, Any]]: + """This method is called before the command is executed. + Args: + command_name (str): The command name. + arguments (Dict[str, Any]): The arguments. + Returns: + Tuple[str, Dict[str, Any]]: The command name and the arguments. + """ + return command_name, arguments + + def can_handle_post_command(self) -> bool: + """This method is called to check that the plugin can + handle the post_command method. + Returns: + bool: True if the plugin can handle the post_command method.""" + return False + + def post_command(self, command_name: str, response: str) -> str: + """This method is called after the command is executed. + Args: + command_name (str): The command name. + response (str): The response. + Returns: + str: The resulting response. + """ + return response + + def can_handle_chat_completion( + self, messages: Dict[Any, Any], model: str, temperature: float, max_tokens: int + ) -> bool: + """This method is called to check that the plugin can + handle the chat_completion method. + Args: + messages (List[Message]): The messages. + model (str): The model name. + temperature (float): The temperature. + max_tokens (int): The max tokens. + Returns: + bool: True if the plugin can handle the chat_completion method.""" + return False + + def handle_chat_completion( + self, messages: List[Message], model: str, temperature: float, max_tokens: int + ) -> str: + """This method is called when the chat completion is done. + Args: + messages (List[Message]): The messages. + model (str): The model name. + temperature (float): The temperature. + max_tokens (int): The max tokens. + Returns: + str: The resulting response. + """ + pass diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 799bcd59b..ffc2a13ed 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -29,7 +29,14 @@ class Config(metaclass=Singleton): "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36" " (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36", ) - + + self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY") + self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID") + self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID") + + self.use_mac_os_tts = False + self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS") + # milvus or zilliz cloud configuration self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530") self.milvus_username = os.getenv("MILVUS_USERNAME") @@ -37,14 +44,20 @@ class Config(metaclass=Singleton): self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt") self.milvus_secure = os.getenv("MILVUS_SECURE") == "True" + self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins") + self.plugins: List[AutoGPTPluginTemplate] = [] + self.plugins_openai = [] + plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS") if plugins_allowlist: self.plugins_allowlist = plugins_allowlist.split(",") else: - self.plugins_allowlist = [] + self.plugins_allowlist = [] - plugins_denylist = os.getenv("DENYLISTED_PLUGINS") + plugins_denylist = os.getenv("DENYLISTED_PLUGINS") if plugins_denylist: + self.plugins_denylist = plugins_denylist.split(",") + else: self.plugins_denylist = [] def set_debug_mode(self, value: bool) -> None: diff --git a/pilot/log/__init__.py b/pilot/log/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/log/json_handler.py b/pilot/log/json_handler.py new file mode 100644 index 000000000..51ae9ae03 --- /dev/null +++ b/pilot/log/json_handler.py @@ -0,0 +1,20 @@ +import json +import logging + + +class JsonFileHandler(logging.FileHandler): + def __init__(self, filename, mode="a", encoding=None, delay=False): + super().__init__(filename, mode, encoding, delay) + + def emit(self, record): + json_data = json.loads(self.format(record)) + with open(self.baseFilename, "w", encoding="utf-8") as f: + json.dump(json_data, f, ensure_ascii=False, indent=4) + + +import logging + + +class JsonFormatter(logging.Formatter): + def format(self, record): + return record.msg diff --git a/pilot/logs.py b/pilot/logs.py new file mode 100644 index 000000000..b5a1fad82 --- /dev/null +++ b/pilot/logs.py @@ -0,0 +1,287 @@ +import logging +import os +import random +import re +import time +from logging import LogRecord +from typing import Any + +from colorama import Fore, Style + +from pilot.log.json_handler import JsonFileHandler, JsonFormatter +from pilot.singleton import Singleton +from pilot.speech import say_text + + +class Logger(metaclass=Singleton): + """ + Logger that handle titles in different colors. + Outputs logs in console, activity.log, and errors.log + For console handler: simulates typing + """ + + def __init__(self): + # create log directory if it doesn't exist + this_files_dir_path = os.path.dirname(__file__) + log_dir = os.path.join(this_files_dir_path, "../logs") + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + log_file = "activity.log" + error_file = "error.log" + + console_formatter = DbGptFormatter("%(title_color)s %(message)s") + + # Create a handler for console which simulate typing + self.typing_console_handler = TypingConsoleHandler() + self.typing_console_handler.setLevel(logging.INFO) + self.typing_console_handler.setFormatter(console_formatter) + + # Create a handler for console without typing simulation + self.console_handler = ConsoleHandler() + self.console_handler.setLevel(logging.DEBUG) + self.console_handler.setFormatter(console_formatter) + + # Info handler in activity.log + self.file_handler = logging.FileHandler( + os.path.join(log_dir, log_file), "a", "utf-8" + ) + self.file_handler.setLevel(logging.DEBUG) + info_formatter = DbGptFormatter( + "%(asctime)s %(levelname)s %(title)s %(message_no_color)s" + ) + self.file_handler.setFormatter(info_formatter) + + # Error handler error.log + error_handler = logging.FileHandler( + os.path.join(log_dir, error_file), "a", "utf-8" + ) + error_handler.setLevel(logging.ERROR) + error_formatter = DbGptFormatter( + "%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s" + " %(message_no_color)s" + ) + error_handler.setFormatter(error_formatter) + + self.typing_logger = logging.getLogger("TYPER") + self.typing_logger.addHandler(self.typing_console_handler) + self.typing_logger.addHandler(self.file_handler) + self.typing_logger.addHandler(error_handler) + self.typing_logger.setLevel(logging.DEBUG) + + self.logger = logging.getLogger("LOGGER") + self.logger.addHandler(self.console_handler) + self.logger.addHandler(self.file_handler) + self.logger.addHandler(error_handler) + self.logger.setLevel(logging.DEBUG) + + self.json_logger = logging.getLogger("JSON_LOGGER") + self.json_logger.addHandler(self.file_handler) + self.json_logger.addHandler(error_handler) + self.json_logger.setLevel(logging.DEBUG) + + self.speak_mode = False + self.chat_plugins = [] + + def typewriter_log( + self, title="", title_color="", content="", speak_text=False, level=logging.INFO + ): + if speak_text and self.speak_mode: + say_text(f"{title}. {content}") + + for plugin in self.chat_plugins: + plugin.report(f"{title}. {content}") + + if content: + if isinstance(content, list): + content = " ".join(content) + else: + content = "" + + self.typing_logger.log( + level, content, extra={"title": title, "color": title_color} + ) + + def debug( + self, + message, + title="", + title_color="", + ): + self._log(title, title_color, message, logging.DEBUG) + + def info( + self, + message, + title="", + title_color="", + ): + self._log(title, title_color, message, logging.INFO) + + def warn( + self, + message, + title="", + title_color="", + ): + self._log(title, title_color, message, logging.WARN) + + def error(self, title, message=""): + self._log(title, Fore.RED, message, logging.ERROR) + + def _log( + self, + title: str = "", + title_color: str = "", + message: str = "", + level=logging.INFO, + ): + if message: + if isinstance(message, list): + message = " ".join(message) + self.logger.log( + level, message, extra={"title": str(title), "color": str(title_color)} + ) + + def set_level(self, level): + self.logger.setLevel(level) + self.typing_logger.setLevel(level) + + def double_check(self, additionalText=None): + if not additionalText: + additionalText = ( + "Please ensure you've setup and configured everything" + " correctly. Read https://github.com/Torantulino/Auto-GPT#readme to " + "double check. You can also create a github issue or join the discord" + " and ask there!" + ) + + self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText) + + def log_json(self, data: Any, file_name: str) -> None: + # Define log directory + this_files_dir_path = os.path.dirname(__file__) + log_dir = os.path.join(this_files_dir_path, "../logs") + + # Create a handler for JSON files + json_file_path = os.path.join(log_dir, file_name) + json_data_handler = JsonFileHandler(json_file_path) + json_data_handler.setFormatter(JsonFormatter()) + + # Log the JSON data using the custom file handler + self.json_logger.addHandler(json_data_handler) + self.json_logger.debug(data) + self.json_logger.removeHandler(json_data_handler) + + def get_log_directory(self): + this_files_dir_path = os.path.dirname(__file__) + log_dir = os.path.join(this_files_dir_path, "../logs") + return os.path.abspath(log_dir) + +""" +Output stream to console using simulated typing +""" + +class TypingConsoleHandler(logging.StreamHandler): + def emit(self, record): + min_typing_speed = 0.05 + max_typing_speed = 0.01 + + msg = self.format(record) + try: + words = msg.split() + for i, word in enumerate(words): + print(word, end="", flush=True) + if i < len(words) - 1: + print(" ", end="", flush=True) + typing_speed = random.uniform(min_typing_speed, max_typing_speed) + time.sleep(typing_speed) + # type faster after each word + min_typing_speed = min_typing_speed * 0.95 + max_typing_speed = max_typing_speed * 0.95 + print() + except Exception: + self.handleError(record) + +class ConsoleHandler(logging.StreamHandler): + def emit(self, record) -> None: + msg = self.format(record) + try: + print(msg) + except Exception: + self.handleError(record) + + +class DbGptFormatter(logging.Formatter): + """ + Allows to handle custom placeholders 'title_color' and 'message_no_color'. + To use this formatter, make sure to pass 'color', 'title' as log extras. + """ + + def format(self, record: LogRecord) -> str: + if hasattr(record, "color"): + record.title_color = ( + getattr(record, "color") + + getattr(record, "title", "") + + " " + + Style.RESET_ALL + ) + else: + record.title_color = getattr(record, "title", "") + + # Add this line to set 'title' to an empty string if it doesn't exist + record.title = getattr(record, "title", "") + + if hasattr(record, "msg"): + record.message_no_color = remove_color_codes(getattr(record, "msg")) + else: + record.message_no_color = "" + return super().format(record) + + +def remove_color_codes(s: str) -> str: + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", s) + + +logger = Logger() + + +def print_assistant_thoughts( + ai_name: object, + assistant_reply_json_valid: object, + speak_mode: bool = False, +) -> None: + assistant_thoughts_reasoning = None + assistant_thoughts_plan = None + assistant_thoughts_speak = None + assistant_thoughts_criticism = None + + assistant_thoughts = assistant_reply_json_valid.get("thoughts", {}) + assistant_thoughts_text = assistant_thoughts.get("text") + if assistant_thoughts: + assistant_thoughts_reasoning = assistant_thoughts.get("reasoning") + assistant_thoughts_plan = assistant_thoughts.get("plan") + assistant_thoughts_criticism = assistant_thoughts.get("criticism") + assistant_thoughts_speak = assistant_thoughts.get("speak") + logger.typewriter_log( + f"{ai_name.upper()} THOUGHTS:", Fore.YELLOW, f"{assistant_thoughts_text}" + ) + logger.typewriter_log("REASONING:", Fore.YELLOW, f"{assistant_thoughts_reasoning}") + if assistant_thoughts_plan: + logger.typewriter_log("PLAN:", Fore.YELLOW, "") + # If it's a list, join it into a string + if isinstance(assistant_thoughts_plan, list): + assistant_thoughts_plan = "\n".join(assistant_thoughts_plan) + elif isinstance(assistant_thoughts_plan, dict): + assistant_thoughts_plan = str(assistant_thoughts_plan) + + # Split the input_string using the newline character and dashes + lines = assistant_thoughts_plan.split("\n") + for line in lines: + line = line.lstrip("- ") + logger.typewriter_log("- ", Fore.GREEN, line.strip()) + logger.typewriter_log("CRITICISM:", Fore.YELLOW, f"{assistant_thoughts_criticism}") + # Speak the assistant's thoughts + if speak_mode and assistant_thoughts_speak: + say_text(assistant_thoughts_speak) diff --git a/pilot/plugins.py b/pilot/plugins.py new file mode 100644 index 000000000..72b0a13a8 --- /dev/null +++ b/pilot/plugins.py @@ -0,0 +1,275 @@ +"""加载组件""" + +import importlib +import json +import os +import zipfile +from pathlib import Path +from typing import List, Optional, Tuple +from urllib.parse import urlparse +from zipimport import zipimporter + +import openapi_python_client +import requests +from auto_gpt_plugin_template import AutoGPTPluginTemplate +from openapi_python_client.cli import Config as OpenAPIConfig + +from pilot.configs.config import Config +from pilot.logs import logger +from pilot.agent.base_open_ai_plugin import BaseOpenAIPlugin + +def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: + """ + 加载zip文件的插件,完全兼容Auto_gpt_plugin + + Args: + zip_path (str): Path to the zipfile. + debug (bool, optional): Enable debug logging. Defaults to False. + + Returns: + list[str]: The list of module names found or empty list if none were found. + """ + result = [] + with zipfile.ZipFile(zip_path, "r") as zfile: + for name in zfile.namelist(): + if name.endswith("__init__.py") and not name.startswith("__MACOSX"): + logger.debug(f"Found module '{name}' in the zipfile at: {name}") + result.append(name) + if len(result) == 0: + logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.") + return result + +def write_dict_to_json_file(data: dict, file_path: str) -> None: + """ + Write a dictionary to a JSON file. + Args: + data (dict): Dictionary to write. + file_path (str): Path to the file. + """ + with open(file_path, "w") as file: + json.dump(data, file, indent=4) + + + +def fetch_openai_plugins_manifest_and_spec(cfg: Config) -> dict: + """ + Fetch the manifest for a list of OpenAI plugins. + Args: + urls (List): List of URLs to fetch. + Returns: + dict: per url dictionary of manifest and spec. + """ + # TODO add directory scan + manifests = {} + for url in cfg.plugins_openai: + openai_plugin_client_dir = f"{cfg.plugins_dir}/openai/{urlparse(url).netloc}" + create_directory_if_not_exists(openai_plugin_client_dir) + if not os.path.exists(f"{openai_plugin_client_dir}/ai-plugin.json"): + try: + response = requests.get(f"{url}/.well-known/ai-plugin.json") + if response.status_code == 200: + manifest = response.json() + if manifest["schema_version"] != "v1": + logger.warn( + f"Unsupported manifest version: {manifest['schem_version']} for {url}" + ) + continue + if manifest["api"]["type"] != "openapi": + logger.warn( + f"Unsupported API type: {manifest['api']['type']} for {url}" + ) + continue + write_dict_to_json_file( + manifest, f"{openai_plugin_client_dir}/ai-plugin.json" + ) + else: + logger.warn( + f"Failed to fetch manifest for {url}: {response.status_code}" + ) + except requests.exceptions.RequestException as e: + logger.warn(f"Error while requesting manifest from {url}: {e}") + else: + logger.info(f"Manifest for {url} already exists") + manifest = json.load(open(f"{openai_plugin_client_dir}/ai-plugin.json")) + if not os.path.exists(f"{openai_plugin_client_dir}/openapi.json"): + openapi_spec = openapi_python_client._get_document( + url=manifest["api"]["url"], path=None, timeout=5 + ) + write_dict_to_json_file( + openapi_spec, f"{openai_plugin_client_dir}/openapi.json" + ) + else: + logger.info(f"OpenAPI spec for {url} already exists") + openapi_spec = json.load(open(f"{openai_plugin_client_dir}/openapi.json")) + manifests[url] = {"manifest": manifest, "openapi_spec": openapi_spec} + return manifests + + + +def create_directory_if_not_exists(directory_path: str) -> bool: + """ + Create a directory if it does not exist. + Args: + directory_path (str): Path to the directory. + Returns: + bool: True if the directory was created, else False. + """ + if not os.path.exists(directory_path): + try: + os.makedirs(directory_path) + logger.debug(f"Created directory: {directory_path}") + return True + except OSError as e: + logger.warn(f"Error creating directory {directory_path}: {e}") + return False + else: + logger.info(f"Directory {directory_path} already exists") + return True + + +def initialize_openai_plugins( + manifests_specs: dict, cfg: Config, debug: bool = False +) -> dict: + """ + Initialize OpenAI plugins. + Args: + manifests_specs (dict): per url dictionary of manifest and spec. + cfg (Config): Config instance including plugins config + debug (bool, optional): Enable debug logging. Defaults to False. + Returns: + dict: per url dictionary of manifest, spec and client. + """ + openai_plugins_dir = f"{cfg.plugins_dir}/openai" + if create_directory_if_not_exists(openai_plugins_dir): + for url, manifest_spec in manifests_specs.items(): + openai_plugin_client_dir = f"{openai_plugins_dir}/{urlparse(url).hostname}" + _meta_option = (openapi_python_client.MetaType.SETUP,) + _config = OpenAPIConfig( + **{ + "project_name_override": "client", + "package_name_override": "client", + } + ) + prev_cwd = Path.cwd() + os.chdir(openai_plugin_client_dir) + Path("ai-plugin.json") + if not os.path.exists("client"): + client_results = openapi_python_client.create_new_client( + url=manifest_spec["manifest"]["api"]["url"], + path=None, + meta=_meta_option, + config=_config, + ) + if client_results: + logger.warn( + f"Error creating OpenAPI client: {client_results[0].header} \n" + f" details: {client_results[0].detail}" + ) + continue + spec = importlib.util.spec_from_file_location( + "client", "client/client/client.py" + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + client = module.Client(base_url=url) + os.chdir(prev_cwd) + manifest_spec["client"] = client + return manifests_specs + + +def instantiate_openai_plugin_clients( + manifests_specs_clients: dict, cfg: Config, debug: bool = False +) -> dict: + """ + Instantiates BaseOpenAIPlugin instances for each OpenAI plugin. + Args: + manifests_specs_clients (dict): per url dictionary of manifest, spec and client. + cfg (Config): Config instance including plugins config + debug (bool, optional): Enable debug logging. Defaults to False. + Returns: + plugins (dict): per url dictionary of BaseOpenAIPlugin instances. + + """ + plugins = {} + for url, manifest_spec_client in manifests_specs_clients.items(): + plugins[url] = BaseOpenAIPlugin(manifest_spec_client) + return plugins + + +def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: + """Scan the plugins directory for plugins and loads them. + + Args: + cfg (Config): Config instance including plugins config + debug (bool, optional): Enable debug logging. Defaults to False. + + Returns: + List[Tuple[str, Path]]: List of plugins. + """ + loaded_plugins = [] + # Generic plugins + plugins_path_path = Path(cfg.plugins_dir) + + logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}") + logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}") + + for plugin in plugins_path_path.glob("*.zip"): + if moduleList := inspect_zip_for_modules(str(plugin), debug): + for module in moduleList: + plugin = Path(plugin) + module = Path(module) + logger.debug(f"Plugin: {plugin} Module: {module}") + zipped_package = zipimporter(str(plugin)) + zipped_module = zipped_package.load_module(str(module.parent)) + for key in dir(zipped_module): + if key.startswith("__"): + continue + a_module = getattr(zipped_module, key) + a_keys = dir(a_module) + if ( + "_abc_impl" in a_keys + and a_module.__name__ != "AutoGPTPluginTemplate" + and denylist_allowlist_check(a_module.__name__, cfg) + ): + loaded_plugins.append(a_module()) + # OpenAI plugins + if cfg.plugins_openai: + manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg) + if manifests_specs.keys(): + manifests_specs_clients = initialize_openai_plugins( + manifests_specs, cfg, debug + ) + for url, openai_plugin_meta in manifests_specs_clients.items(): + if denylist_allowlist_check(url, cfg): + plugin = BaseOpenAIPlugin(openai_plugin_meta) + loaded_plugins.append(plugin) + + if loaded_plugins: + logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") + for plugin in loaded_plugins: + logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}") + return loaded_plugins + + +def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool: + """Check if the plugin is in the allowlist or denylist. + + Args: + plugin_name (str): Name of the plugin. + cfg (Config): Config object. + + Returns: + True or False + """ + logger.debug(f"Checking if plugin {plugin_name} should be loaded") + if plugin_name in cfg.plugins_denylist: + logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.") + return False + if plugin_name in cfg.plugins_allowlist: + logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.") + return True + ack = input( + f"WARNING: Plugin {plugin_name} found. But not in the" + f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): " + ) + return ack.lower() == cfg.authorise_key diff --git a/pilot/speech/__init__.py b/pilot/speech/__init__.py new file mode 100644 index 000000000..58bfda398 --- /dev/null +++ b/pilot/speech/__init__.py @@ -0,0 +1,3 @@ +from pilot.speech.say import say_text + +__all__ = ["say_text"] diff --git a/pilot/speech/base.py b/pilot/speech/base.py new file mode 100644 index 000000000..55154df10 --- /dev/null +++ b/pilot/speech/base.py @@ -0,0 +1,50 @@ +"""Base class for all voice classes.""" +import abc +from threading import Lock + +from pilot.singleton import AbstractSingleton + + +class VoiceBase(AbstractSingleton): + """ + Base class for all voice classes. + """ + + def __init__(self): + """ + Initialize the voice class. + """ + self._url = None + self._headers = None + self._api_key = None + self._voices = [] + self._mutex = Lock() + self._setup() + + def say(self, text: str, voice_index: int = 0) -> bool: + """ + Say the given text. + + Args: + text (str): The text to say. + voice_index (int): The index of the voice to use. + """ + with self._mutex: + return self._speech(text, voice_index) + + @abc.abstractmethod + def _setup(self) -> None: + """ + Setup the voices, API key, etc. + """ + pass + + @abc.abstractmethod + def _speech(self, text: str, voice_index: int = 0) -> bool: + """ + Play the given text. + + Args: + text (str): The text to play. + """ + pass diff --git a/pilot/speech/brian.py b/pilot/speech/brian.py new file mode 100644 index 000000000..505c9a6f8 --- /dev/null +++ b/pilot/speech/brian.py @@ -0,0 +1,43 @@ +import logging +import os + +import requests +from playsound import playsound + +from pilot.speech.base import VoiceBase + + +class BrianSpeech(VoiceBase): + """Brian speech module for autogpt""" + + def _setup(self) -> None: + """Setup the voices, API key, etc.""" + pass + + def _speech(self, text: str, _: int = 0) -> bool: + """Speak text using Brian with the streamelements API + + Args: + text (str): The text to speak + + Returns: + bool: True if the request was successful, False otherwise + """ + tts_url = ( + f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}" + ) + response = requests.get(tts_url) + + if response.status_code == 200: + with open("speech.mp3", "wb") as f: + f.write(response.content) + playsound("speech.mp3") + os.remove("speech.mp3") + return True + else: + logging.error( + "Request failed with status code: %s, response content: %s", + response.status_code, + response.content, + ) + return False diff --git a/pilot/speech/eleven_labs.py b/pilot/speech/eleven_labs.py new file mode 100644 index 000000000..8a93ab5ae --- /dev/null +++ b/pilot/speech/eleven_labs.py @@ -0,0 +1,88 @@ +"""ElevenLabs speech module""" +import os + +import requests +from playsound import playsound + +from pilot.configs.config import Config +from pilot.speech.base import VoiceBase + +PLACEHOLDERS = {"your-voice-id"} + + +class ElevenLabsSpeech(VoiceBase): + """ElevenLabs speech class""" + + def _setup(self) -> None: + """Set up the voices, API key, etc. + + Returns: + None: None + """ + + cfg = Config() + default_voices = ["ErXwobaYiN019PkySvjV", "EXAVITQu4vr4xnSDxMaL"] + voice_options = { + "Rachel": "21m00Tcm4TlvDq8ikWAM", + "Domi": "AZnzlk1XvdvUeBnXmlld", + "Bella": "EXAVITQu4vr4xnSDxMaL", + "Antoni": "ErXwobaYiN019PkySvjV", + "Elli": "MF3mGyEYCl7XYWbV9V6O", + "Josh": "TxGEqnHWrfWFTfGW9XjX", + "Arnold": "VR6AewLTigWG4xSOukaG", + "Adam": "pNInz6obpgDQGcFmaJgB", + "Sam": "yoZ06aMxZJJ28mfd3POQ", + } + self._headers = { + "Content-Type": "application/json", + "xi-api-key": cfg.elevenlabs_api_key, + } + self._voices = default_voices.copy() + if cfg.elevenlabs_voice_1_id in voice_options: + cfg.elevenlabs_voice_1_id = voice_options[cfg.elevenlabs_voice_1_id] + if cfg.elevenlabs_voice_2_id in voice_options: + cfg.elevenlabs_voice_2_id = voice_options[cfg.elevenlabs_voice_2_id] + self._use_custom_voice(cfg.elevenlabs_voice_1_id, 0) + self._use_custom_voice(cfg.elevenlabs_voice_2_id, 1) + + def _use_custom_voice(self, voice, voice_index) -> None: + """Use a custom voice if provided and not a placeholder + + Args: + voice (str): The voice ID + voice_index (int): The voice index + + Returns: + None: None + """ + # Placeholder values that should be treated as empty + if voice and voice not in PLACEHOLDERS: + self._voices[voice_index] = voice + + def _speech(self, text: str, voice_index: int = 0) -> bool: + """Speak text using elevenlabs.io's API + + Args: + text (str): The text to speak + voice_index (int, optional): The voice to use. Defaults to 0. + + Returns: + bool: True if the request was successful, False otherwise + """ + from autogpt.logs import logger + + tts_url = ( + f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}" + ) + response = requests.post(tts_url, headers=self._headers, json={"text": text}) + + if response.status_code == 200: + with open("speech.mpeg", "wb") as f: + f.write(response.content) + playsound("speech.mpeg", True) + os.remove("speech.mpeg") + return True + else: + logger.warn("Request failed with status code:", response.status_code) + logger.info("Response content:", response.content) + return False diff --git a/pilot/speech/gtts.py b/pilot/speech/gtts.py new file mode 100644 index 000000000..7ad164f30 --- /dev/null +++ b/pilot/speech/gtts.py @@ -0,0 +1,22 @@ +""" GTTS Voice. """ +import os + +import gtts +from playsound import playsound + +from pilot.speech.base import VoiceBase + + +class GTTSVoice(VoiceBase): + """GTTS Voice.""" + + def _setup(self) -> None: + pass + + def _speech(self, text: str, _: int = 0) -> bool: + """Play the given text.""" + tts = gtts.gTTS(text) + tts.save("speech.mp3") + playsound("speech.mp3", True) + os.remove("speech.mp3") + return True diff --git a/pilot/speech/macos_tts.py b/pilot/speech/macos_tts.py new file mode 100644 index 000000000..51292c240 --- /dev/null +++ b/pilot/speech/macos_tts.py @@ -0,0 +1,21 @@ +""" MacOS TTS Voice. """ +import os + +from pilot.speech.base import VoiceBase + + +class MacOSTTS(VoiceBase): + """MacOS TTS Voice.""" + + def _setup(self) -> None: + pass + + def _speech(self, text: str, voice_index: int = 0) -> bool: + """Play the given text.""" + if voice_index == 0: + os.system(f'say "{text}"') + elif voice_index == 1: + os.system(f'say -v "Ava (Premium)" "{text}"') + else: + os.system(f'say -v Samantha "{text}"') + return True diff --git a/pilot/speech/say.py b/pilot/speech/say.py new file mode 100644 index 000000000..b0f6b0516 --- /dev/null +++ b/pilot/speech/say.py @@ -0,0 +1,46 @@ +""" Text to speech module """ +import threading +from threading import Semaphore + +from pilot.configs.config import Config +from pilot.speech.base import VoiceBase +from pilot.speech.brian import BrianSpeech +from pilot.speech.eleven_labs import ElevenLabsSpeech +from pilot.speech.gtts import GTTSVoice +from pilot.speech.macos_tts import MacOSTTS + +_QUEUE_SEMAPHORE = Semaphore( + 1 +) # The amount of sounds to queue before blocking the main thread + + +def say_text(text: str, voice_index: int = 0) -> None: + """Speak the given text using the given voice index""" + cfg = Config() + default_voice_engine, voice_engine = _get_voice_engine(cfg) + + def speak() -> None: + success = voice_engine.say(text, voice_index) + if not success: + default_voice_engine.say(text) + + _QUEUE_SEMAPHORE.release() + + _QUEUE_SEMAPHORE.acquire(True) + thread = threading.Thread(target=speak) + thread.start() + + +def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]: + """Get the voice engine to use for the given configuration""" + default_voice_engine = GTTSVoice() + if config.elevenlabs_api_key: + voice_engine = ElevenLabsSpeech() + elif config.use_mac_os_tts == "True": + voice_engine = MacOSTTS() + elif config.use_brian_tts == "True": + voice_engine = BrianSpeech() + else: + voice_engine = GTTSVoice() + + return default_voice_engine, voice_engine diff --git a/plugins/Auto-GPT-TiDB-Serverless-Plugin-main.zip b/plugins/Auto-GPT-TiDB-Serverless-Plugin-main.zip new file mode 100644 index 000000000..1c7b32736 Binary files /dev/null and b/plugins/Auto-GPT-TiDB-Serverless-Plugin-main.zip differ diff --git a/plugins/__PUT_PLUGIN_ZIPS_HERE__ b/plugins/__PUT_PLUGIN_ZIPS_HERE__ new file mode 100644 index 000000000..e69de29bb diff --git a/requirements.txt b/requirements.txt index b22b2d2ad..76dc12370 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,4 +54,23 @@ fschat==0.1.10 llama-index==0.5.27 pymysql unstructured==0.6.3 -pytesseract==0.3.10 \ No newline at end of file +pytesseract==0.3.10 +auto-gpt-plugin-template +pymdown-extensions +mkdocs +requests +gTTS==2.3.1 + +# OpenAI and Generic plugins import +openapi-python-client==0.13.4 + +# Testing dependencies +pytest +asynctest +pytest-asyncio +pytest-benchmark +pytest-cov +pytest-integration +pytest-mock +vcrpy +pytest-recording \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip b/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip new file mode 100644 index 000000000..00bc1f4f5 Binary files /dev/null and b/tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip differ diff --git a/tests/unit/test_plugins.py b/tests/unit/test_plugins.py new file mode 100644 index 000000000..21dbaaf27 --- /dev/null +++ b/tests/unit/test_plugins.py @@ -0,0 +1,135 @@ +import pytest +import os + + +from pilot.configs.config import Config +from pilot.plugins import ( + denylist_allowlist_check, + inspect_zip_for_modules, + scan_plugins, +) + +PLUGINS_TEST_DIR = "tests/unit/data/test_plugins" +PLUGINS_TEST_DIR_TEMP = "data/test_plugins" +PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip" +PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py" +PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/" + +def test_inspect_zip_for_modules(): + current_dir = os.getcwd() + print(current_dir) + result = inspect_zip_for_modules(str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}")) + assert result == [PLUGIN_TEST_INIT_PY] + + +@pytest.fixture +def mock_config_denylist_allowlist_check(): + class MockConfig: + """Mock config object for testing the denylist_allowlist_check function""" + + plugins_denylist = ["BadPlugin"] + plugins_allowlist = ["GoodPlugin"] + authorise_key = "y" + exit_key = "n" + + return MockConfig() + + +def test_denylist_allowlist_check_denylist( + mock_config_denylist_allowlist_check, monkeypatch +): + # Test that the function returns False when the plugin is in the denylist + monkeypatch.setattr("builtins.input", lambda _: "y") + assert not denylist_allowlist_check( + "BadPlugin", mock_config_denylist_allowlist_check + ) + + +def test_denylist_allowlist_check_allowlist( + mock_config_denylist_allowlist_check, monkeypatch +): + # Test that the function returns True when the plugin is in the allowlist + monkeypatch.setattr("builtins.input", lambda _: "y") + assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check) + + +def test_denylist_allowlist_check_user_input_yes( + mock_config_denylist_allowlist_check, monkeypatch +): + # Test that the function returns True when the user inputs "y" + monkeypatch.setattr("builtins.input", lambda _: "y") + assert denylist_allowlist_check( + "UnknownPlugin", mock_config_denylist_allowlist_check + ) + + +def test_denylist_allowlist_check_user_input_no( + mock_config_denylist_allowlist_check, monkeypatch +): + # Test that the function returns False when the user inputs "n" + monkeypatch.setattr("builtins.input", lambda _: "n") + assert not denylist_allowlist_check( + "UnknownPlugin", mock_config_denylist_allowlist_check + ) + + +def test_denylist_allowlist_check_user_input_invalid( + mock_config_denylist_allowlist_check, monkeypatch +): + # Test that the function returns False when the user inputs an invalid value + monkeypatch.setattr("builtins.input", lambda _: "invalid") + assert not denylist_allowlist_check( + "UnknownPlugin", mock_config_denylist_allowlist_check + ) + + +@pytest.fixture +def config_with_plugins(): + """Mock config object for testing the scan_plugins function""" + # Test that the function returns the correct number of plugins + cfg = Config() + cfg.plugins_dir = PLUGINS_TEST_DIR + cfg.plugins_openai = ["https://weathergpt.vercel.app/"] + return cfg + + +@pytest.fixture +def mock_config_openai_plugin(): + """Mock config object for testing the scan_plugins function""" + + class MockConfig: + """Mock config object for testing the scan_plugins function""" + current_dir = os.getcwd() + plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/" + plugins_openai = [PLUGIN_TEST_OPENAI] + plugins_denylist = ["AutoGPTPVicuna"] + plugins_allowlist = [PLUGIN_TEST_OPENAI] + + return MockConfig() + + +def test_scan_plugins_openai(mock_config_openai_plugin): + # Test that the function returns the correct number of plugins + result = scan_plugins(mock_config_openai_plugin, debug=True) + assert len(result) == 1 + + +@pytest.fixture +def mock_config_generic_plugin(): + """Mock config object for testing the scan_plugins function""" + + # Test that the function returns the correct number of plugins + class MockConfig: + current_dir = os.getcwd() + plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/" + plugins_openai = [] + plugins_denylist = [] + plugins_allowlist = ["AutoGPTPVicuna"] + + return MockConfig() + + +def test_scan_plugins_generic(mock_config_generic_plugin): + # Test that the function returns the correct number of plugins + result = scan_plugins(mock_config_generic_plugin, debug=True) + assert len(result) == 1