mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
兼容autogpt plugin load
This commit is contained in:
parent
fa965999e1
commit
f684a7f6f0
3
.idea/.gitignore
vendored
Normal file
3
.idea/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
12
.idea/DB-GPT.iml
Normal file
12
.idea/DB-GPT.iml
Normal file
@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.10 (Auto-GPT)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="GOOGLE" />
|
||||
<option name="myDocStringFormat" value="Google" />
|
||||
</component>
|
||||
</module>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
4
.idea/misc.xml
Normal file
4
.idea/misc.xml
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (Auto-GPT)" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/DB-GPT.iml" filepath="$PROJECT_DIR$/.idea/DB-GPT.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
199
pilot/agent/base_open_ai_plugin.py
Normal file
199
pilot/agent/base_open_ai_plugin.py
Normal file
@ -0,0 +1,199 @@
|
||||
"""Handles loading of plugins."""
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
PromptGenerator = TypeVar("PromptGenerator")
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class BaseOpenAIPlugin(AutoGPTPluginTemplate):
|
||||
"""
|
||||
This is a BaseOpenAIPlugin class for generating Auto-GPT plugins.
|
||||
"""
|
||||
|
||||
def __init__(self, manifests_specs_clients: dict):
|
||||
# super().__init__()
|
||||
self._name = manifests_specs_clients["manifest"]["name_for_model"]
|
||||
self._version = manifests_specs_clients["manifest"]["schema_version"]
|
||||
self._description = manifests_specs_clients["manifest"]["description_for_model"]
|
||||
self._client = manifests_specs_clients["client"]
|
||||
self._manifest = manifests_specs_clients["manifest"]
|
||||
self._openapi_spec = manifests_specs_clients["openapi_spec"]
|
||||
|
||||
def can_handle_on_response(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_response method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_response method."""
|
||||
return False
|
||||
|
||||
def on_response(self, response: str, *args, **kwargs) -> str:
|
||||
"""This method is called when a response is received from the model."""
|
||||
return response
|
||||
|
||||
def can_handle_post_prompt(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_prompt method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_prompt method."""
|
||||
return False
|
||||
|
||||
def post_prompt(self, prompt: PromptGenerator) -> PromptGenerator:
|
||||
"""This method is called just after the generate_prompt is called,
|
||||
but actually before the prompt is generated.
|
||||
Args:
|
||||
prompt (PromptGenerator): The prompt generator.
|
||||
Returns:
|
||||
PromptGenerator: The prompt generator.
|
||||
"""
|
||||
return prompt
|
||||
|
||||
def can_handle_on_planning(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_planning method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_planning method."""
|
||||
return False
|
||||
|
||||
def on_planning(
|
||||
self, prompt: PromptGenerator, messages: List[Message]
|
||||
) -> Optional[str]:
|
||||
"""This method is called before the planning chat completion is done.
|
||||
Args:
|
||||
prompt (PromptGenerator): The prompt generator.
|
||||
messages (List[str]): The list of messages.
|
||||
"""
|
||||
pass
|
||||
|
||||
def can_handle_post_planning(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_planning method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_planning method."""
|
||||
return False
|
||||
|
||||
def post_planning(self, response: str) -> str:
|
||||
"""This method is called after the planning chat completion is done.
|
||||
Args:
|
||||
response (str): The response.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
return response
|
||||
|
||||
def can_handle_pre_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the pre_instruction method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the pre_instruction method."""
|
||||
return False
|
||||
|
||||
def pre_instruction(self, messages: List[Message]) -> List[Message]:
|
||||
"""This method is called before the instruction chat is done.
|
||||
Args:
|
||||
messages (List[Message]): The list of context messages.
|
||||
Returns:
|
||||
List[Message]: The resulting list of messages.
|
||||
"""
|
||||
return messages
|
||||
|
||||
def can_handle_on_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the on_instruction method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the on_instruction method."""
|
||||
return False
|
||||
|
||||
def on_instruction(self, messages: List[Message]) -> Optional[str]:
|
||||
"""This method is called when the instruction chat is done.
|
||||
Args:
|
||||
messages (List[Message]): The list of context messages.
|
||||
Returns:
|
||||
Optional[str]: The resulting message.
|
||||
"""
|
||||
pass
|
||||
|
||||
def can_handle_post_instruction(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_instruction method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_instruction method."""
|
||||
return False
|
||||
|
||||
def post_instruction(self, response: str) -> str:
|
||||
"""This method is called after the instruction chat is done.
|
||||
Args:
|
||||
response (str): The response.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
return response
|
||||
|
||||
def can_handle_pre_command(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the pre_command method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the pre_command method."""
|
||||
return False
|
||||
|
||||
def pre_command(
|
||||
self, command_name: str, arguments: Dict[str, Any]
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""This method is called before the command is executed.
|
||||
Args:
|
||||
command_name (str): The command name.
|
||||
arguments (Dict[str, Any]): The arguments.
|
||||
Returns:
|
||||
Tuple[str, Dict[str, Any]]: The command name and the arguments.
|
||||
"""
|
||||
return command_name, arguments
|
||||
|
||||
def can_handle_post_command(self) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the post_command method.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the post_command method."""
|
||||
return False
|
||||
|
||||
def post_command(self, command_name: str, response: str) -> str:
|
||||
"""This method is called after the command is executed.
|
||||
Args:
|
||||
command_name (str): The command name.
|
||||
response (str): The response.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
return response
|
||||
|
||||
def can_handle_chat_completion(
|
||||
self, messages: Dict[Any, Any], model: str, temperature: float, max_tokens: int
|
||||
) -> bool:
|
||||
"""This method is called to check that the plugin can
|
||||
handle the chat_completion method.
|
||||
Args:
|
||||
messages (List[Message]): The messages.
|
||||
model (str): The model name.
|
||||
temperature (float): The temperature.
|
||||
max_tokens (int): The max tokens.
|
||||
Returns:
|
||||
bool: True if the plugin can handle the chat_completion method."""
|
||||
return False
|
||||
|
||||
def handle_chat_completion(
|
||||
self, messages: List[Message], model: str, temperature: float, max_tokens: int
|
||||
) -> str:
|
||||
"""This method is called when the chat completion is done.
|
||||
Args:
|
||||
messages (List[Message]): The messages.
|
||||
model (str): The model name.
|
||||
temperature (float): The temperature.
|
||||
max_tokens (int): The max tokens.
|
||||
Returns:
|
||||
str: The resulting response.
|
||||
"""
|
||||
pass
|
@ -29,7 +29,14 @@ class Config(metaclass=Singleton):
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36"
|
||||
" (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
|
||||
)
|
||||
|
||||
|
||||
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
|
||||
self.elevenlabs_voice_1_id = os.getenv("ELEVENLABS_VOICE_1_ID")
|
||||
self.elevenlabs_voice_2_id = os.getenv("ELEVENLABS_VOICE_2_ID")
|
||||
|
||||
self.use_mac_os_tts = False
|
||||
self.use_mac_os_tts = os.getenv("USE_MAC_OS_TTS")
|
||||
|
||||
# milvus or zilliz cloud configuration
|
||||
self.milvus_addr = os.getenv("MILVUS_ADDR", "localhost:19530")
|
||||
self.milvus_username = os.getenv("MILVUS_USERNAME")
|
||||
@ -37,14 +44,20 @@ class Config(metaclass=Singleton):
|
||||
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
|
||||
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
|
||||
|
||||
self.plugins_dir = os.getenv("PLUGINS_DIR", "plugins")
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins_openai = []
|
||||
|
||||
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS")
|
||||
if plugins_allowlist:
|
||||
self.plugins_allowlist = plugins_allowlist.split(",")
|
||||
else:
|
||||
self.plugins_allowlist = []
|
||||
self.plugins_allowlist = []
|
||||
|
||||
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
|
||||
plugins_denylist = os.getenv("DENYLISTED_PLUGINS")
|
||||
if plugins_denylist:
|
||||
self.plugins_denylist = plugins_denylist.split(",")
|
||||
else:
|
||||
self.plugins_denylist = []
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
|
0
pilot/log/__init__.py
Normal file
0
pilot/log/__init__.py
Normal file
20
pilot/log/json_handler.py
Normal file
20
pilot/log/json_handler.py
Normal file
@ -0,0 +1,20 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFileHandler(logging.FileHandler):
|
||||
def __init__(self, filename, mode="a", encoding=None, delay=False):
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
|
||||
def emit(self, record):
|
||||
json_data = json.loads(self.format(record))
|
||||
with open(self.baseFilename, "w", encoding="utf-8") as f:
|
||||
json.dump(json_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
return record.msg
|
287
pilot/logs.py
Normal file
287
pilot/logs.py
Normal file
@ -0,0 +1,287 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from logging import LogRecord
|
||||
from typing import Any
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
||||
from pilot.log.json_handler import JsonFileHandler, JsonFormatter
|
||||
from pilot.singleton import Singleton
|
||||
from pilot.speech import say_text
|
||||
|
||||
|
||||
class Logger(metaclass=Singleton):
|
||||
"""
|
||||
Logger that handle titles in different colors.
|
||||
Outputs logs in console, activity.log, and errors.log
|
||||
For console handler: simulates typing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# create log directory if it doesn't exist
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
log_file = "activity.log"
|
||||
error_file = "error.log"
|
||||
|
||||
console_formatter = DbGptFormatter("%(title_color)s %(message)s")
|
||||
|
||||
# Create a handler for console which simulate typing
|
||||
self.typing_console_handler = TypingConsoleHandler()
|
||||
self.typing_console_handler.setLevel(logging.INFO)
|
||||
self.typing_console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Create a handler for console without typing simulation
|
||||
self.console_handler = ConsoleHandler()
|
||||
self.console_handler.setLevel(logging.DEBUG)
|
||||
self.console_handler.setFormatter(console_formatter)
|
||||
|
||||
# Info handler in activity.log
|
||||
self.file_handler = logging.FileHandler(
|
||||
os.path.join(log_dir, log_file), "a", "utf-8"
|
||||
)
|
||||
self.file_handler.setLevel(logging.DEBUG)
|
||||
info_formatter = DbGptFormatter(
|
||||
"%(asctime)s %(levelname)s %(title)s %(message_no_color)s"
|
||||
)
|
||||
self.file_handler.setFormatter(info_formatter)
|
||||
|
||||
# Error handler error.log
|
||||
error_handler = logging.FileHandler(
|
||||
os.path.join(log_dir, error_file), "a", "utf-8"
|
||||
)
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
error_formatter = DbGptFormatter(
|
||||
"%(asctime)s %(levelname)s %(module)s:%(funcName)s:%(lineno)d %(title)s"
|
||||
" %(message_no_color)s"
|
||||
)
|
||||
error_handler.setFormatter(error_formatter)
|
||||
|
||||
self.typing_logger = logging.getLogger("TYPER")
|
||||
self.typing_logger.addHandler(self.typing_console_handler)
|
||||
self.typing_logger.addHandler(self.file_handler)
|
||||
self.typing_logger.addHandler(error_handler)
|
||||
self.typing_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.logger = logging.getLogger("LOGGER")
|
||||
self.logger.addHandler(self.console_handler)
|
||||
self.logger.addHandler(self.file_handler)
|
||||
self.logger.addHandler(error_handler)
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.json_logger = logging.getLogger("JSON_LOGGER")
|
||||
self.json_logger.addHandler(self.file_handler)
|
||||
self.json_logger.addHandler(error_handler)
|
||||
self.json_logger.setLevel(logging.DEBUG)
|
||||
|
||||
self.speak_mode = False
|
||||
self.chat_plugins = []
|
||||
|
||||
def typewriter_log(
|
||||
self, title="", title_color="", content="", speak_text=False, level=logging.INFO
|
||||
):
|
||||
if speak_text and self.speak_mode:
|
||||
say_text(f"{title}. {content}")
|
||||
|
||||
for plugin in self.chat_plugins:
|
||||
plugin.report(f"{title}. {content}")
|
||||
|
||||
if content:
|
||||
if isinstance(content, list):
|
||||
content = " ".join(content)
|
||||
else:
|
||||
content = ""
|
||||
|
||||
self.typing_logger.log(
|
||||
level, content, extra={"title": title, "color": title_color}
|
||||
)
|
||||
|
||||
def debug(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.DEBUG)
|
||||
|
||||
def info(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.INFO)
|
||||
|
||||
def warn(
|
||||
self,
|
||||
message,
|
||||
title="",
|
||||
title_color="",
|
||||
):
|
||||
self._log(title, title_color, message, logging.WARN)
|
||||
|
||||
def error(self, title, message=""):
|
||||
self._log(title, Fore.RED, message, logging.ERROR)
|
||||
|
||||
def _log(
|
||||
self,
|
||||
title: str = "",
|
||||
title_color: str = "",
|
||||
message: str = "",
|
||||
level=logging.INFO,
|
||||
):
|
||||
if message:
|
||||
if isinstance(message, list):
|
||||
message = " ".join(message)
|
||||
self.logger.log(
|
||||
level, message, extra={"title": str(title), "color": str(title_color)}
|
||||
)
|
||||
|
||||
def set_level(self, level):
|
||||
self.logger.setLevel(level)
|
||||
self.typing_logger.setLevel(level)
|
||||
|
||||
def double_check(self, additionalText=None):
|
||||
if not additionalText:
|
||||
additionalText = (
|
||||
"Please ensure you've setup and configured everything"
|
||||
" correctly. Read https://github.com/Torantulino/Auto-GPT#readme to "
|
||||
"double check. You can also create a github issue or join the discord"
|
||||
" and ask there!"
|
||||
)
|
||||
|
||||
self.typewriter_log("DOUBLE CHECK CONFIGURATION", Fore.YELLOW, additionalText)
|
||||
|
||||
def log_json(self, data: Any, file_name: str) -> None:
|
||||
# Define log directory
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
|
||||
# Create a handler for JSON files
|
||||
json_file_path = os.path.join(log_dir, file_name)
|
||||
json_data_handler = JsonFileHandler(json_file_path)
|
||||
json_data_handler.setFormatter(JsonFormatter())
|
||||
|
||||
# Log the JSON data using the custom file handler
|
||||
self.json_logger.addHandler(json_data_handler)
|
||||
self.json_logger.debug(data)
|
||||
self.json_logger.removeHandler(json_data_handler)
|
||||
|
||||
def get_log_directory(self):
|
||||
this_files_dir_path = os.path.dirname(__file__)
|
||||
log_dir = os.path.join(this_files_dir_path, "../logs")
|
||||
return os.path.abspath(log_dir)
|
||||
|
||||
"""
|
||||
Output stream to console using simulated typing
|
||||
"""
|
||||
|
||||
class TypingConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record):
|
||||
min_typing_speed = 0.05
|
||||
max_typing_speed = 0.01
|
||||
|
||||
msg = self.format(record)
|
||||
try:
|
||||
words = msg.split()
|
||||
for i, word in enumerate(words):
|
||||
print(word, end="", flush=True)
|
||||
if i < len(words) - 1:
|
||||
print(" ", end="", flush=True)
|
||||
typing_speed = random.uniform(min_typing_speed, max_typing_speed)
|
||||
time.sleep(typing_speed)
|
||||
# type faster after each word
|
||||
min_typing_speed = min_typing_speed * 0.95
|
||||
max_typing_speed = max_typing_speed * 0.95
|
||||
print()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
class ConsoleHandler(logging.StreamHandler):
|
||||
def emit(self, record) -> None:
|
||||
msg = self.format(record)
|
||||
try:
|
||||
print(msg)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class DbGptFormatter(logging.Formatter):
|
||||
"""
|
||||
Allows to handle custom placeholders 'title_color' and 'message_no_color'.
|
||||
To use this formatter, make sure to pass 'color', 'title' as log extras.
|
||||
"""
|
||||
|
||||
def format(self, record: LogRecord) -> str:
|
||||
if hasattr(record, "color"):
|
||||
record.title_color = (
|
||||
getattr(record, "color")
|
||||
+ getattr(record, "title", "")
|
||||
+ " "
|
||||
+ Style.RESET_ALL
|
||||
)
|
||||
else:
|
||||
record.title_color = getattr(record, "title", "")
|
||||
|
||||
# Add this line to set 'title' to an empty string if it doesn't exist
|
||||
record.title = getattr(record, "title", "")
|
||||
|
||||
if hasattr(record, "msg"):
|
||||
record.message_no_color = remove_color_codes(getattr(record, "msg"))
|
||||
else:
|
||||
record.message_no_color = ""
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def remove_color_codes(s: str) -> str:
|
||||
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||||
return ansi_escape.sub("", s)
|
||||
|
||||
|
||||
logger = Logger()
|
||||
|
||||
|
||||
def print_assistant_thoughts(
|
||||
ai_name: object,
|
||||
assistant_reply_json_valid: object,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
assistant_thoughts_reasoning = None
|
||||
assistant_thoughts_plan = None
|
||||
assistant_thoughts_speak = None
|
||||
assistant_thoughts_criticism = None
|
||||
|
||||
assistant_thoughts = assistant_reply_json_valid.get("thoughts", {})
|
||||
assistant_thoughts_text = assistant_thoughts.get("text")
|
||||
if assistant_thoughts:
|
||||
assistant_thoughts_reasoning = assistant_thoughts.get("reasoning")
|
||||
assistant_thoughts_plan = assistant_thoughts.get("plan")
|
||||
assistant_thoughts_criticism = assistant_thoughts.get("criticism")
|
||||
assistant_thoughts_speak = assistant_thoughts.get("speak")
|
||||
logger.typewriter_log(
|
||||
f"{ai_name.upper()} THOUGHTS:", Fore.YELLOW, f"{assistant_thoughts_text}"
|
||||
)
|
||||
logger.typewriter_log("REASONING:", Fore.YELLOW, f"{assistant_thoughts_reasoning}")
|
||||
if assistant_thoughts_plan:
|
||||
logger.typewriter_log("PLAN:", Fore.YELLOW, "")
|
||||
# If it's a list, join it into a string
|
||||
if isinstance(assistant_thoughts_plan, list):
|
||||
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
|
||||
elif isinstance(assistant_thoughts_plan, dict):
|
||||
assistant_thoughts_plan = str(assistant_thoughts_plan)
|
||||
|
||||
# Split the input_string using the newline character and dashes
|
||||
lines = assistant_thoughts_plan.split("\n")
|
||||
for line in lines:
|
||||
line = line.lstrip("- ")
|
||||
logger.typewriter_log("- ", Fore.GREEN, line.strip())
|
||||
logger.typewriter_log("CRITICISM:", Fore.YELLOW, f"{assistant_thoughts_criticism}")
|
||||
# Speak the assistant's thoughts
|
||||
if speak_mode and assistant_thoughts_speak:
|
||||
say_text(assistant_thoughts_speak)
|
275
pilot/plugins.py
Normal file
275
pilot/plugins.py
Normal file
@ -0,0 +1,275 @@
|
||||
"""加载组件"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import zipimporter
|
||||
|
||||
import openapi_python_client
|
||||
import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from openapi_python_client.cli import Config as OpenAPIConfig
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
from pilot.agent.base_open_ai_plugin import BaseOpenAIPlugin
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
"""
|
||||
加载zip文件的插件,完全兼容Auto_gpt_plugin
|
||||
|
||||
Args:
|
||||
zip_path (str): Path to the zipfile.
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
list[str]: The list of module names found or empty list if none were found.
|
||||
"""
|
||||
result = []
|
||||
with zipfile.ZipFile(zip_path, "r") as zfile:
|
||||
for name in zfile.namelist():
|
||||
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
|
||||
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
|
||||
result.append(name)
|
||||
if len(result) == 0:
|
||||
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
|
||||
return result
|
||||
|
||||
def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||
"""
|
||||
Write a dictionary to a JSON file.
|
||||
Args:
|
||||
data (dict): Dictionary to write.
|
||||
file_path (str): Path to the file.
|
||||
"""
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=4)
|
||||
|
||||
|
||||
|
||||
def fetch_openai_plugins_manifest_and_spec(cfg: Config) -> dict:
|
||||
"""
|
||||
Fetch the manifest for a list of OpenAI plugins.
|
||||
Args:
|
||||
urls (List): List of URLs to fetch.
|
||||
Returns:
|
||||
dict: per url dictionary of manifest and spec.
|
||||
"""
|
||||
# TODO add directory scan
|
||||
manifests = {}
|
||||
for url in cfg.plugins_openai:
|
||||
openai_plugin_client_dir = f"{cfg.plugins_dir}/openai/{urlparse(url).netloc}"
|
||||
create_directory_if_not_exists(openai_plugin_client_dir)
|
||||
if not os.path.exists(f"{openai_plugin_client_dir}/ai-plugin.json"):
|
||||
try:
|
||||
response = requests.get(f"{url}/.well-known/ai-plugin.json")
|
||||
if response.status_code == 200:
|
||||
manifest = response.json()
|
||||
if manifest["schema_version"] != "v1":
|
||||
logger.warn(
|
||||
f"Unsupported manifest version: {manifest['schem_version']} for {url}"
|
||||
)
|
||||
continue
|
||||
if manifest["api"]["type"] != "openapi":
|
||||
logger.warn(
|
||||
f"Unsupported API type: {manifest['api']['type']} for {url}"
|
||||
)
|
||||
continue
|
||||
write_dict_to_json_file(
|
||||
manifest, f"{openai_plugin_client_dir}/ai-plugin.json"
|
||||
)
|
||||
else:
|
||||
logger.warn(
|
||||
f"Failed to fetch manifest for {url}: {response.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warn(f"Error while requesting manifest from {url}: {e}")
|
||||
else:
|
||||
logger.info(f"Manifest for {url} already exists")
|
||||
manifest = json.load(open(f"{openai_plugin_client_dir}/ai-plugin.json"))
|
||||
if not os.path.exists(f"{openai_plugin_client_dir}/openapi.json"):
|
||||
openapi_spec = openapi_python_client._get_document(
|
||||
url=manifest["api"]["url"], path=None, timeout=5
|
||||
)
|
||||
write_dict_to_json_file(
|
||||
openapi_spec, f"{openai_plugin_client_dir}/openapi.json"
|
||||
)
|
||||
else:
|
||||
logger.info(f"OpenAPI spec for {url} already exists")
|
||||
openapi_spec = json.load(open(f"{openai_plugin_client_dir}/openapi.json"))
|
||||
manifests[url] = {"manifest": manifest, "openapi_spec": openapi_spec}
|
||||
return manifests
|
||||
|
||||
|
||||
|
||||
def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||
"""
|
||||
Create a directory if it does not exist.
|
||||
Args:
|
||||
directory_path (str): Path to the directory.
|
||||
Returns:
|
||||
bool: True if the directory was created, else False.
|
||||
"""
|
||||
if not os.path.exists(directory_path):
|
||||
try:
|
||||
os.makedirs(directory_path)
|
||||
logger.debug(f"Created directory: {directory_path}")
|
||||
return True
|
||||
except OSError as e:
|
||||
logger.warn(f"Error creating directory {directory_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"Directory {directory_path} already exists")
|
||||
return True
|
||||
|
||||
|
||||
def initialize_openai_plugins(
|
||||
manifests_specs: dict, cfg: Config, debug: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Initialize OpenAI plugins.
|
||||
Args:
|
||||
manifests_specs (dict): per url dictionary of manifest and spec.
|
||||
cfg (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
Returns:
|
||||
dict: per url dictionary of manifest, spec and client.
|
||||
"""
|
||||
openai_plugins_dir = f"{cfg.plugins_dir}/openai"
|
||||
if create_directory_if_not_exists(openai_plugins_dir):
|
||||
for url, manifest_spec in manifests_specs.items():
|
||||
openai_plugin_client_dir = f"{openai_plugins_dir}/{urlparse(url).hostname}"
|
||||
_meta_option = (openapi_python_client.MetaType.SETUP,)
|
||||
_config = OpenAPIConfig(
|
||||
**{
|
||||
"project_name_override": "client",
|
||||
"package_name_override": "client",
|
||||
}
|
||||
)
|
||||
prev_cwd = Path.cwd()
|
||||
os.chdir(openai_plugin_client_dir)
|
||||
Path("ai-plugin.json")
|
||||
if not os.path.exists("client"):
|
||||
client_results = openapi_python_client.create_new_client(
|
||||
url=manifest_spec["manifest"]["api"]["url"],
|
||||
path=None,
|
||||
meta=_meta_option,
|
||||
config=_config,
|
||||
)
|
||||
if client_results:
|
||||
logger.warn(
|
||||
f"Error creating OpenAPI client: {client_results[0].header} \n"
|
||||
f" details: {client_results[0].detail}"
|
||||
)
|
||||
continue
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"client", "client/client/client.py"
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
client = module.Client(base_url=url)
|
||||
os.chdir(prev_cwd)
|
||||
manifest_spec["client"] = client
|
||||
return manifests_specs
|
||||
|
||||
|
||||
def instantiate_openai_plugin_clients(
|
||||
manifests_specs_clients: dict, cfg: Config, debug: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Instantiates BaseOpenAIPlugin instances for each OpenAI plugin.
|
||||
Args:
|
||||
manifests_specs_clients (dict): per url dictionary of manifest, spec and client.
|
||||
cfg (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
Returns:
|
||||
plugins (dict): per url dictionary of BaseOpenAIPlugin instances.
|
||||
|
||||
"""
|
||||
plugins = {}
|
||||
for url, manifest_spec_client in manifests_specs_clients.items():
|
||||
plugins[url] = BaseOpenAIPlugin(manifest_spec_client)
|
||||
return plugins
|
||||
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
cfg (Config): Config instance including plugins config
|
||||
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Path]]: List of plugins.
|
||||
"""
|
||||
loaded_plugins = []
|
||||
# Generic plugins
|
||||
plugins_path_path = Path(cfg.plugins_dir)
|
||||
|
||||
logger.debug(f"Allowlisted Plugins: {cfg.plugins_allowlist}")
|
||||
logger.debug(f"Denylisted Plugins: {cfg.plugins_denylist}")
|
||||
|
||||
for plugin in plugins_path_path.glob("*.zip"):
|
||||
if moduleList := inspect_zip_for_modules(str(plugin), debug):
|
||||
for module in moduleList:
|
||||
plugin = Path(plugin)
|
||||
module = Path(module)
|
||||
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||
zipped_package = zipimporter(str(plugin))
|
||||
zipped_module = zipped_package.load_module(str(module.parent))
|
||||
for key in dir(zipped_module):
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
a_module = getattr(zipped_module, key)
|
||||
a_keys = dir(a_module)
|
||||
if (
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
):
|
||||
loaded_plugins.append(a_module())
|
||||
# OpenAI plugins
|
||||
if cfg.plugins_openai:
|
||||
manifests_specs = fetch_openai_plugins_manifest_and_spec(cfg)
|
||||
if manifests_specs.keys():
|
||||
manifests_specs_clients = initialize_openai_plugins(
|
||||
manifests_specs, cfg, debug
|
||||
)
|
||||
for url, openai_plugin_meta in manifests_specs_clients.items():
|
||||
if denylist_allowlist_check(url, cfg):
|
||||
plugin = BaseOpenAIPlugin(openai_plugin_meta)
|
||||
loaded_plugins.append(plugin)
|
||||
|
||||
if loaded_plugins:
|
||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||
for plugin in loaded_plugins:
|
||||
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||
"""Check if the plugin is in the allowlist or denylist.
|
||||
|
||||
Args:
|
||||
plugin_name (str): Name of the plugin.
|
||||
cfg (Config): Config object.
|
||||
|
||||
Returns:
|
||||
True or False
|
||||
"""
|
||||
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
|
||||
if plugin_name in cfg.plugins_denylist:
|
||||
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
|
||||
return False
|
||||
if plugin_name in cfg.plugins_allowlist:
|
||||
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
|
||||
return True
|
||||
ack = input(
|
||||
f"WARNING: Plugin {plugin_name} found. But not in the"
|
||||
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
|
||||
)
|
||||
return ack.lower() == cfg.authorise_key
|
3
pilot/speech/__init__.py
Normal file
3
pilot/speech/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from pilot.speech.say import say_text
|
||||
|
||||
__all__ = ["say_text"]
|
50
pilot/speech/base.py
Normal file
50
pilot/speech/base.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Base class for all voice classes."""
|
||||
import abc
|
||||
from threading import Lock
|
||||
|
||||
from pilot.singleton import AbstractSingleton
|
||||
|
||||
|
||||
class VoiceBase(AbstractSingleton):
|
||||
"""
|
||||
Base class for all voice classes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the voice class.
|
||||
"""
|
||||
self._url = None
|
||||
self._headers = None
|
||||
self._api_key = None
|
||||
self._voices = []
|
||||
self._mutex = Lock()
|
||||
self._setup()
|
||||
|
||||
def say(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
Say the given text.
|
||||
|
||||
Args:
|
||||
text (str): The text to say.
|
||||
voice_index (int): The index of the voice to use.
|
||||
"""
|
||||
with self._mutex:
|
||||
return self._speech(text, voice_index)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _setup(self) -> None:
|
||||
"""
|
||||
Setup the voices, API key, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""
|
||||
Play the given text.
|
||||
|
||||
Args:
|
||||
text (str): The text to play.
|
||||
"""
|
||||
pass
|
43
pilot/speech/brian.py
Normal file
43
pilot/speech/brian.py
Normal file
@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
|
||||
class BrianSpeech(VoiceBase):
|
||||
"""Brian speech module for autogpt"""
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Setup the voices, API key, etc."""
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
"""Speak text using Brian with the streamelements API
|
||||
|
||||
Args:
|
||||
text (str): The text to speak
|
||||
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
tts_url = (
|
||||
f"https://api.streamelements.com/kappa/v2/speech?voice=Brian&text={text}"
|
||||
)
|
||||
response = requests.get(tts_url)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open("speech.mp3", "wb") as f:
|
||||
f.write(response.content)
|
||||
playsound("speech.mp3")
|
||||
os.remove("speech.mp3")
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
"Request failed with status code: %s, response content: %s",
|
||||
response.status_code,
|
||||
response.content,
|
||||
)
|
||||
return False
|
88
pilot/speech/eleven_labs.py
Normal file
88
pilot/speech/eleven_labs.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""ElevenLabs speech module"""
|
||||
import os
|
||||
|
||||
import requests
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
PLACEHOLDERS = {"your-voice-id"}
|
||||
|
||||
|
||||
class ElevenLabsSpeech(VoiceBase):
|
||||
"""ElevenLabs speech class"""
|
||||
|
||||
def _setup(self) -> None:
|
||||
"""Set up the voices, API key, etc.
|
||||
|
||||
Returns:
|
||||
None: None
|
||||
"""
|
||||
|
||||
cfg = Config()
|
||||
default_voices = ["ErXwobaYiN019PkySvjV", "EXAVITQu4vr4xnSDxMaL"]
|
||||
voice_options = {
|
||||
"Rachel": "21m00Tcm4TlvDq8ikWAM",
|
||||
"Domi": "AZnzlk1XvdvUeBnXmlld",
|
||||
"Bella": "EXAVITQu4vr4xnSDxMaL",
|
||||
"Antoni": "ErXwobaYiN019PkySvjV",
|
||||
"Elli": "MF3mGyEYCl7XYWbV9V6O",
|
||||
"Josh": "TxGEqnHWrfWFTfGW9XjX",
|
||||
"Arnold": "VR6AewLTigWG4xSOukaG",
|
||||
"Adam": "pNInz6obpgDQGcFmaJgB",
|
||||
"Sam": "yoZ06aMxZJJ28mfd3POQ",
|
||||
}
|
||||
self._headers = {
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": cfg.elevenlabs_api_key,
|
||||
}
|
||||
self._voices = default_voices.copy()
|
||||
if cfg.elevenlabs_voice_1_id in voice_options:
|
||||
cfg.elevenlabs_voice_1_id = voice_options[cfg.elevenlabs_voice_1_id]
|
||||
if cfg.elevenlabs_voice_2_id in voice_options:
|
||||
cfg.elevenlabs_voice_2_id = voice_options[cfg.elevenlabs_voice_2_id]
|
||||
self._use_custom_voice(cfg.elevenlabs_voice_1_id, 0)
|
||||
self._use_custom_voice(cfg.elevenlabs_voice_2_id, 1)
|
||||
|
||||
def _use_custom_voice(self, voice, voice_index) -> None:
|
||||
"""Use a custom voice if provided and not a placeholder
|
||||
|
||||
Args:
|
||||
voice (str): The voice ID
|
||||
voice_index (int): The voice index
|
||||
|
||||
Returns:
|
||||
None: None
|
||||
"""
|
||||
# Placeholder values that should be treated as empty
|
||||
if voice and voice not in PLACEHOLDERS:
|
||||
self._voices[voice_index] = voice
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""Speak text using elevenlabs.io's API
|
||||
|
||||
Args:
|
||||
text (str): The text to speak
|
||||
voice_index (int, optional): The voice to use. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from autogpt.logs import logger
|
||||
|
||||
tts_url = (
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{self._voices[voice_index]}"
|
||||
)
|
||||
response = requests.post(tts_url, headers=self._headers, json={"text": text})
|
||||
|
||||
if response.status_code == 200:
|
||||
with open("speech.mpeg", "wb") as f:
|
||||
f.write(response.content)
|
||||
playsound("speech.mpeg", True)
|
||||
os.remove("speech.mpeg")
|
||||
return True
|
||||
else:
|
||||
logger.warn("Request failed with status code:", response.status_code)
|
||||
logger.info("Response content:", response.content)
|
||||
return False
|
22
pilot/speech/gtts.py
Normal file
22
pilot/speech/gtts.py
Normal file
@ -0,0 +1,22 @@
|
||||
""" GTTS Voice. """
|
||||
import os
|
||||
|
||||
import gtts
|
||||
from playsound import playsound
|
||||
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
|
||||
class GTTSVoice(VoiceBase):
|
||||
"""GTTS Voice."""
|
||||
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, _: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
tts = gtts.gTTS(text)
|
||||
tts.save("speech.mp3")
|
||||
playsound("speech.mp3", True)
|
||||
os.remove("speech.mp3")
|
||||
return True
|
21
pilot/speech/macos_tts.py
Normal file
21
pilot/speech/macos_tts.py
Normal file
@ -0,0 +1,21 @@
|
||||
""" MacOS TTS Voice. """
|
||||
import os
|
||||
|
||||
from pilot.speech.base import VoiceBase
|
||||
|
||||
|
||||
class MacOSTTS(VoiceBase):
|
||||
"""MacOS TTS Voice."""
|
||||
|
||||
def _setup(self) -> None:
|
||||
pass
|
||||
|
||||
def _speech(self, text: str, voice_index: int = 0) -> bool:
|
||||
"""Play the given text."""
|
||||
if voice_index == 0:
|
||||
os.system(f'say "{text}"')
|
||||
elif voice_index == 1:
|
||||
os.system(f'say -v "Ava (Premium)" "{text}"')
|
||||
else:
|
||||
os.system(f'say -v Samantha "{text}"')
|
||||
return True
|
46
pilot/speech/say.py
Normal file
46
pilot/speech/say.py
Normal file
@ -0,0 +1,46 @@
|
||||
""" Text to speech module """
|
||||
import threading
|
||||
from threading import Semaphore
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.speech.base import VoiceBase
|
||||
from pilot.speech.brian import BrianSpeech
|
||||
from pilot.speech.eleven_labs import ElevenLabsSpeech
|
||||
from pilot.speech.gtts import GTTSVoice
|
||||
from pilot.speech.macos_tts import MacOSTTS
|
||||
|
||||
_QUEUE_SEMAPHORE = Semaphore(
|
||||
1
|
||||
) # The amount of sounds to queue before blocking the main thread
|
||||
|
||||
|
||||
def say_text(text: str, voice_index: int = 0) -> None:
|
||||
"""Speak the given text using the given voice index"""
|
||||
cfg = Config()
|
||||
default_voice_engine, voice_engine = _get_voice_engine(cfg)
|
||||
|
||||
def speak() -> None:
|
||||
success = voice_engine.say(text, voice_index)
|
||||
if not success:
|
||||
default_voice_engine.say(text)
|
||||
|
||||
_QUEUE_SEMAPHORE.release()
|
||||
|
||||
_QUEUE_SEMAPHORE.acquire(True)
|
||||
thread = threading.Thread(target=speak)
|
||||
thread.start()
|
||||
|
||||
|
||||
def _get_voice_engine(config: Config) -> tuple[VoiceBase, VoiceBase]:
|
||||
"""Get the voice engine to use for the given configuration"""
|
||||
default_voice_engine = GTTSVoice()
|
||||
if config.elevenlabs_api_key:
|
||||
voice_engine = ElevenLabsSpeech()
|
||||
elif config.use_mac_os_tts == "True":
|
||||
voice_engine = MacOSTTS()
|
||||
elif config.use_brian_tts == "True":
|
||||
voice_engine = BrianSpeech()
|
||||
else:
|
||||
voice_engine = GTTSVoice()
|
||||
|
||||
return default_voice_engine, voice_engine
|
BIN
plugins/Auto-GPT-TiDB-Serverless-Plugin-main.zip
Normal file
BIN
plugins/Auto-GPT-TiDB-Serverless-Plugin-main.zip
Normal file
Binary file not shown.
0
plugins/__PUT_PLUGIN_ZIPS_HERE__
Normal file
0
plugins/__PUT_PLUGIN_ZIPS_HERE__
Normal file
@ -54,4 +54,23 @@ fschat==0.1.10
|
||||
llama-index==0.5.27
|
||||
pymysql
|
||||
unstructured==0.6.3
|
||||
pytesseract==0.3.10
|
||||
pytesseract==0.3.10
|
||||
auto-gpt-plugin-template
|
||||
pymdown-extensions
|
||||
mkdocs
|
||||
requests
|
||||
gTTS==2.3.1
|
||||
|
||||
# OpenAI and Generic plugins import
|
||||
openapi-python-client==0.13.4
|
||||
|
||||
# Testing dependencies
|
||||
pytest
|
||||
asynctest
|
||||
pytest-asyncio
|
||||
pytest-benchmark
|
||||
pytest-cov
|
||||
pytest-integration
|
||||
pytest-mock
|
||||
vcrpy
|
||||
pytest-recording
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
BIN
tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip
Normal file
BIN
tests/unit/data/test_plugins/Auto-GPT-Plugin-Test-master.zip
Normal file
Binary file not shown.
135
tests/unit/test_plugins.py
Normal file
135
tests/unit/test_plugins.py
Normal file
@ -0,0 +1,135 @@
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.plugins import (
|
||||
denylist_allowlist_check,
|
||||
inspect_zip_for_modules,
|
||||
scan_plugins,
|
||||
)
|
||||
|
||||
PLUGINS_TEST_DIR = "tests/unit/data/test_plugins"
|
||||
PLUGINS_TEST_DIR_TEMP = "data/test_plugins"
|
||||
PLUGIN_TEST_ZIP_FILE = "Auto-GPT-Plugin-Test-master.zip"
|
||||
PLUGIN_TEST_INIT_PY = "Auto-GPT-Plugin-Test-master/src/auto_gpt_vicuna/__init__.py"
|
||||
PLUGIN_TEST_OPENAI = "https://weathergpt.vercel.app/"
|
||||
|
||||
def test_inspect_zip_for_modules():
|
||||
current_dir = os.getcwd()
|
||||
print(current_dir)
|
||||
result = inspect_zip_for_modules(str(f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/{PLUGIN_TEST_ZIP_FILE}"))
|
||||
assert result == [PLUGIN_TEST_INIT_PY]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_denylist_allowlist_check():
|
||||
class MockConfig:
|
||||
"""Mock config object for testing the denylist_allowlist_check function"""
|
||||
|
||||
plugins_denylist = ["BadPlugin"]
|
||||
plugins_allowlist = ["GoodPlugin"]
|
||||
authorise_key = "y"
|
||||
exit_key = "n"
|
||||
|
||||
return MockConfig()
|
||||
|
||||
|
||||
def test_denylist_allowlist_check_denylist(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns False when the plugin is in the denylist
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
assert not denylist_allowlist_check(
|
||||
"BadPlugin", mock_config_denylist_allowlist_check
|
||||
)
|
||||
|
||||
|
||||
def test_denylist_allowlist_check_allowlist(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns True when the plugin is in the allowlist
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
assert denylist_allowlist_check("GoodPlugin", mock_config_denylist_allowlist_check)
|
||||
|
||||
|
||||
def test_denylist_allowlist_check_user_input_yes(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns True when the user inputs "y"
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
assert denylist_allowlist_check(
|
||||
"UnknownPlugin", mock_config_denylist_allowlist_check
|
||||
)
|
||||
|
||||
|
||||
def test_denylist_allowlist_check_user_input_no(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns False when the user inputs "n"
|
||||
monkeypatch.setattr("builtins.input", lambda _: "n")
|
||||
assert not denylist_allowlist_check(
|
||||
"UnknownPlugin", mock_config_denylist_allowlist_check
|
||||
)
|
||||
|
||||
|
||||
def test_denylist_allowlist_check_user_input_invalid(
|
||||
mock_config_denylist_allowlist_check, monkeypatch
|
||||
):
|
||||
# Test that the function returns False when the user inputs an invalid value
|
||||
monkeypatch.setattr("builtins.input", lambda _: "invalid")
|
||||
assert not denylist_allowlist_check(
|
||||
"UnknownPlugin", mock_config_denylist_allowlist_check
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_with_plugins():
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
# Test that the function returns the correct number of plugins
|
||||
cfg = Config()
|
||||
cfg.plugins_dir = PLUGINS_TEST_DIR
|
||||
cfg.plugins_openai = ["https://weathergpt.vercel.app/"]
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_openai_plugin():
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
current_dir = os.getcwd()
|
||||
plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/"
|
||||
plugins_openai = [PLUGIN_TEST_OPENAI]
|
||||
plugins_denylist = ["AutoGPTPVicuna"]
|
||||
plugins_allowlist = [PLUGIN_TEST_OPENAI]
|
||||
|
||||
return MockConfig()
|
||||
|
||||
|
||||
def test_scan_plugins_openai(mock_config_openai_plugin):
|
||||
# Test that the function returns the correct number of plugins
|
||||
result = scan_plugins(mock_config_openai_plugin, debug=True)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_generic_plugin():
|
||||
"""Mock config object for testing the scan_plugins function"""
|
||||
|
||||
# Test that the function returns the correct number of plugins
|
||||
class MockConfig:
|
||||
current_dir = os.getcwd()
|
||||
plugins_dir = f"{current_dir}/{PLUGINS_TEST_DIR_TEMP}/"
|
||||
plugins_openai = []
|
||||
plugins_denylist = []
|
||||
plugins_allowlist = ["AutoGPTPVicuna"]
|
||||
|
||||
return MockConfig()
|
||||
|
||||
|
||||
def test_scan_plugins_generic(mock_config_generic_plugin):
|
||||
# Test that the function returns the correct number of plugins
|
||||
result = scan_plugins(mock_config_generic_plugin, debug=True)
|
||||
assert len(result) == 1
|
Loading…
Reference in New Issue
Block a user