diff --git a/.gitignore b/.gitignore index 55e18c91f..78d507abd 100644 --- a/.gitignore +++ b/.gitignore @@ -149,4 +149,6 @@ pilot/mock_datas/db-gpt-test.db.wal logswebserver.log.* .history/* -.plugin_env \ No newline at end of file +.plugin_env +/pilot/meta_data/alembic/versions/* +/pilot/meta_data/*.db \ No newline at end of file diff --git a/pilot/agent/__init__.py b/pilot/agent/__init__.py deleted file mode 100644 index c53f601b3..000000000 --- a/pilot/agent/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- diff --git a/pilot/agent/agent.py b/pilot/agent/agent.py deleted file mode 100644 index 8d8220b4a..000000000 --- a/pilot/agent/agent.py +++ /dev/null @@ -1,12 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- - - -class Agent: - """Agent class for interacting with DB-GPT - - Attributes: - """ - - def __init__(self) -> None: - pass diff --git a/pilot/agent/agent_manager.py b/pilot/agent/agent_manager.py deleted file mode 100644 index 31b55eb65..000000000 --- a/pilot/agent/agent_manager.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -"""Agent manager for managing GPT agents""" -from __future__ import annotations - -from pilot.configs.config import Config -from pilot.model.base import Message -from pilot.singleton import Singleton - - -class AgentManager(metaclass=Singleton): - """Agent manager for managing GPT agents""" - - def __init__(self): - self.next_key = 0 - self.agents = {} # key, (task, full_message_history, model) - self.cfg = Config() - - """Agent manager for managing DB-GPT agents - In order to compatible auto gpt plugins, - we use the same template with it. - - Args: next_keys - agents - cfg - """ - - def __init__(self) -> None: - self.next_key = 0 - self.agents = {} # TODO need to define - self.cfg = Config() - - # Create new GPT agent - # TODO: Centralise use of create_chat_completion() to globally enforce token limit - - def create_agent(self, task: str, prompt: str, model: str) -> tuple[int, str]: - """Create a new agent and return its key - - Args: - task: The task to perform - prompt: The prompt to use - model: The model to use - - Returns: - The key of the new agent - """ - - def message_agent(self, key: str | int, message: str) -> str: - """Send a message to an agent and return its response - - Args: - key: The key of the agent to message - message: The message to send to the agent - - Returns: - The agent's response - """ - - def list_agents(self) -> list[tuple[str | int, str]]: - """Return a list of all agents - - Returns: - A list of tuples of the form (key, task) - """ - - # Return a list of agent keys and their tasks - return [(key, task) for key, (task, _, _) in self.agents.items()] - - def delete_agent(self, key: str | int) -> bool: - """Delete an agent from the agent manager - - Args: - key: The key of the agent to delete - - Returns: - True if successful, False otherwise - """ - - try: - del self.agents[int(key)] - return True - except KeyError: - return False diff --git a/pilot/agent/json_fix_llm.py b/pilot/agent/json_fix_llm.py deleted file mode 100644 index aa7c4cec8..000000000 --- a/pilot/agent/json_fix_llm.py +++ /dev/null @@ -1,122 +0,0 @@ -import contextlib -import json -from typing import Any, Dict - -from colorama import Fore -from regex import regex - -from pilot.configs.config import Config -from pilot.json_utils.json_fix_general import ( - add_quotes_to_property_names, - balance_braces, - fix_invalid_escape, -) -from pilot.logs import logger - - -CFG = Config() - - -def fix_and_parse_json( - json_to_load: str, try_to_fix_with_gpt: bool = True -) -> Dict[Any, Any]: - """Fix and parse JSON string - - Args: - json_to_load (str): The JSON string. - try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT. - Defaults to True. - - Returns: - str or dict[Any, Any]: The parsed JSON. - """ - - with contextlib.suppress(json.JSONDecodeError): - json_to_load = json_to_load.replace("\t", "") - return json.loads(json_to_load) - - with contextlib.suppress(json.JSONDecodeError): - json_to_load = correct_json(json_to_load) - return json.loads(json_to_load) - # Let's do something manually: - # sometimes GPT responds with something BEFORE the braces: - # "I'm sorry, I don't understand. Please try again." - # {"text": "I'm sorry, I don't understand. Please try again.", - # "confidence": 0.0} - # So let's try to find the first brace and then parse the rest - # of the string - try: - brace_index = json_to_load.index("{") - maybe_fixed_json = json_to_load[brace_index:] - last_brace_index = maybe_fixed_json.rindex("}") - maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1] - return json.loads(maybe_fixed_json) - except (json.JSONDecodeError, ValueError) as e: - logger.error("参数解析错误", e) - - -def correct_json(json_to_load: str) -> str: - """ - Correct common JSON errors. - Args: - json_to_load (str): The JSON string. - """ - - try: - logger.debug("json", json_to_load) - json.loads(json_to_load) - return json_to_load - except json.JSONDecodeError as e: - logger.debug("json loads error", e) - error_message = str(e) - if error_message.startswith("Invalid \\escape"): - json_to_load = fix_invalid_escape(json_to_load, error_message) - if error_message.startswith( - "Expecting property name enclosed in double quotes" - ): - json_to_load = add_quotes_to_property_names(json_to_load) - try: - json.loads(json_to_load) - return json_to_load - except json.JSONDecodeError as e: - logger.debug("json loads error - add quotes", e) - error_message = str(e) - if balanced_str := balance_braces(json_to_load): - return balanced_str - return json_to_load - - -def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str): - from pilot.speech.say import say_text - - if CFG.speak_mode and CFG.debug_mode: - say_text( - "I have received an invalid JSON response from the OpenAI API. " - "Trying to fix it now." - ) - logger.error("Attempting to fix JSON by finding outermost brackets\n") - - try: - json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}") - json_match = json_pattern.search(json_string) - - if json_match: - # Extract the valid JSON object from the string - json_string = json_match.group(0) - logger.typewriter_log( - title="Apparently json was fixed.", title_color=Fore.GREEN - ) - if CFG.speak_mode and CFG.debug_mode: - say_text("Apparently json was fixed.") - else: - return {} - - except (json.JSONDecodeError, ValueError): - if CFG.debug_mode: - logger.error(f"Error: Invalid JSON: {json_string}\n") - if CFG.speak_mode: - say_text("Didn't work. I will have to ignore this response then.") - logger.error("Error: Invalid JSON, setting it to empty JSON now.\n") - json_string = {} - - return fix_and_parse_json(json_string) diff --git a/pilot/base_modules/__init__.py b/pilot/base_modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/agent/__init__.py b/pilot/base_modules/agent/__init__.py new file mode 100644 index 000000000..017425ac0 --- /dev/null +++ b/pilot/base_modules/agent/__init__.py @@ -0,0 +1,6 @@ +from .db.my_plugin_db import MyPluginEntity, MyPluginDao +from .db.plugin_hub_db import PluginHubEntity, PluginHubDao + +from .commands.command import execute_command, get_command +from .commands.generator import PluginPromptGenerator +from .commands.disply_type.show_chart_gen import static_message_img_path \ No newline at end of file diff --git a/pilot/base_modules/agent/agent.py b/pilot/base_modules/agent/agent.py new file mode 100644 index 000000000..d11f21ab5 --- /dev/null +++ b/pilot/base_modules/agent/agent.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod + + + + + + +class AgentFacade(ABC): + def __init__(self) -> None: + self.model = None + + + diff --git a/pilot/base_modules/agent/commands/__init__.py b/pilot/base_modules/agent/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/agent/commands/built_in/__init__.py b/pilot/base_modules/agent/commands/built_in/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/agent/commands/built_in/audio_text.py b/pilot/base_modules/agent/commands/built_in/audio_text.py new file mode 100644 index 000000000..5c27ba9da --- /dev/null +++ b/pilot/base_modules/agent/commands/built_in/audio_text.py @@ -0,0 +1,61 @@ +"""Commands for converting audio to text.""" +import json + +import requests + +from pilot.base_modules.agent.commands.command_mange import command +from pilot.configs.config import Config + +CFG = Config() + + +@command( + "read_audio_from_file", + "Convert Audio to text", + '"filename": ""', + CFG.huggingface_audio_to_text_model, + "Configure huggingface_audio_to_text_model.", +) +def read_audio_from_file(filename: str) -> str: + """ + Convert audio to text. + + Args: + filename (str): The path to the audio file + + Returns: + str: The text from the audio + """ + with open(filename, "rb") as audio_file: + audio = audio_file.read() + return read_audio(audio) + + +def read_audio(audio: bytes) -> str: + """ + Convert audio to text. + + Args: + audio (bytes): The audio to convert + + Returns: + str: The text from the audio + """ + model = CFG.huggingface_audio_to_text_model + api_url = f"https://api-inference.huggingface.co/models/{model}" + api_token = CFG.huggingface_api_token + headers = {"Authorization": f"Bearer {api_token}"} + + if api_token is None: + raise ValueError( + "You need to set your Hugging Face API token in the config file." + ) + + response = requests.post( + api_url, + headers=headers, + data=audio, + ) + + text = json.loads(response.content.decode("utf-8"))["text"] + return f"The audio says: {text}" diff --git a/pilot/base_modules/agent/commands/built_in/image_gen.py b/pilot/base_modules/agent/commands/built_in/image_gen.py new file mode 100644 index 000000000..2f6fe62d3 --- /dev/null +++ b/pilot/base_modules/agent/commands/built_in/image_gen.py @@ -0,0 +1,123 @@ +""" Image Generation Module for AutoGPT.""" +import io +import uuid +from base64 import b64decode + +import requests +from PIL import Image + +from pilot.base_modules.agent.commands.command_mange import command +from pilot.configs.config import Config +from pilot.logs import logger + +CFG = Config() + + +@command("generate_image", "Generate Image", '"prompt": ""', CFG.image_provider) +def generate_image(prompt: str, size: int = 256) -> str: + """Generate an image from a prompt. + + Args: + prompt (str): The prompt to use + size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace) + + Returns: + str: The filename of the image + """ + filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg" + + # HuggingFace + if CFG.image_provider == "huggingface": + return generate_image_with_hf(prompt, filename) + # SD WebUI + elif CFG.image_provider == "sdwebui": + return generate_image_with_sd_webui(prompt, filename, size) + return "No Image Provider Set" + + +def generate_image_with_hf(prompt: str, filename: str) -> str: + """Generate an image with HuggingFace's API. + + Args: + prompt (str): The prompt to use + filename (str): The filename to save the image to + + Returns: + str: The filename of the image + """ + API_URL = ( + f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}" + ) + if CFG.huggingface_api_token is None: + raise ValueError( + "You need to set your Hugging Face API token in the config file." + ) + headers = { + "Authorization": f"Bearer {CFG.huggingface_api_token}", + "X-Use-Cache": "false", + } + + response = requests.post( + API_URL, + headers=headers, + json={ + "inputs": prompt, + }, + ) + + image = Image.open(io.BytesIO(response.content)) + logger.info(f"Image Generated for prompt:{prompt}") + + image.save(filename) + + return f"Saved to disk:{filename}" + + +def generate_image_with_sd_webui( + prompt: str, + filename: str, + size: int = 512, + negative_prompt: str = "", + extra: dict = {}, +) -> str: + """Generate an image with Stable Diffusion webui. + Args: + prompt (str): The prompt to use + filename (str): The filename to save the image to + size (int, optional): The size of the image. Defaults to 256. + negative_prompt (str, optional): The negative prompt to use. Defaults to "". + extra (dict, optional): Extra parameters to pass to the API. Defaults to {}. + Returns: + str: The filename of the image + """ + # Create a session and set the basic auth if needed + s = requests.Session() + if CFG.sd_webui_auth: + username, password = CFG.sd_webui_auth.split(":") + s.auth = (username, password or "") + + # Generate the images + response = requests.post( + f"{CFG.sd_webui_url}/sdapi/v1/txt2img", + json={ + "prompt": prompt, + "negative_prompt": negative_prompt, + "sampler_index": "DDIM", + "steps": 20, + "cfg_scale": 7.0, + "width": size, + "height": size, + "n_iter": 1, + **extra, + }, + ) + + logger.info(f"Image Generated for prompt:{prompt}") + + # Save the image to disk + response = response.json() + b64 = b64decode(response["images"][0].split(",", 1)[0]) + image = Image.open(io.BytesIO(b64)) + image.save(filename) + + return f"Saved to disk:{filename}" diff --git a/pilot/base_modules/agent/commands/command.py b/pilot/base_modules/agent/commands/command.py new file mode 100644 index 000000000..b4cb0c0f6 --- /dev/null +++ b/pilot/base_modules/agent/commands/command.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import json +from typing import Dict + +from .exception_not_commands import NotCommands +from .generator import PluginPromptGenerator + +from pilot.configs.config import Config + +def _resolve_pathlike_command_args(command_args): + if "directory" in command_args and command_args["directory"] in {"", "/"}: + # todo + command_args["directory"] = "" + else: + for pathlike in ["filename", "directory", "clone_path"]: + if pathlike in command_args: + # todo + command_args[pathlike] = "" + return command_args + + +def execute_ai_response_json( + prompt: PluginPromptGenerator, + ai_response, + user_input: str = None, +) -> str: + """ + + Args: + command_registry: + ai_response: + prompt: + + Returns: + + """ + from pilot.speech.say import say_text + + cfg = Config() + + command_name, arguments = get_command(ai_response) + + if cfg.speak_mode: + say_text(f"I want to execute {command_name}") + + arguments = _resolve_pathlike_command_args(arguments) + # Execute command + if command_name is not None and command_name.lower().startswith("error"): + result = f"Command {command_name} threw the following error: {arguments}" + elif command_name == "human_feedback": + result = f"Human feedback: {user_input}" + else: + for plugin in cfg.plugins: + if not plugin.can_handle_pre_command(): + continue + command_name, arguments = plugin.pre_command(command_name, arguments) + command_result = execute_command( + command_name, + arguments, + prompt, + ) + result = f"{command_result}" + return result + + +def execute_command( + command_name: str, + arguments, + prompt: PluginPromptGenerator, +): + """Execute the command and return the result + + Args: + command_name (str): The name of the command to execute + arguments (dict): The arguments for the command + + Returns: + str: The result of the command + """ + + cmd = prompt.command_registry.commands.get(command_name) + + # If the command is found, call it with the provided arguments + if cmd: + try: + return cmd(**arguments) + except Exception as e: + return f"Error: {str(e)}" + # TODO: Change these to take in a file rather than pasted code, if + # non-file is given, return instructions "Input should be a python + # filepath, write your code to file and try again + else: + for command in prompt.commands: + if ( + command_name == command["label"].lower() + or command_name == command["name"].lower() + ): + try: + # 删除非定义参数 + diff_ags = list( + set(arguments.keys()).difference(set(command["args"].keys())) + ) + for arg_name in diff_ags: + del arguments[arg_name] + print(str(arguments)) + return command["function"](**arguments) + except Exception as e: + return f"Error: {str(e)}" + raise NotCommands("非可用命令" + command_name) + + +def get_command(response_json: Dict): + """Parse the response and return the command name and arguments + + Args: + response_json (json): The response from the AI + + Returns: + tuple: The command name and arguments + + Raises: + json.decoder.JSONDecodeError: If the response is not valid JSON + + Exception: If any other error occurs + """ + try: + if "command" not in response_json: + return "Error:", "Missing 'command' object in JSON" + + if not isinstance(response_json, dict): + return "Error:", f"'response_json' object is not dictionary {response_json}" + + command = response_json["command"] + if not isinstance(command, dict): + return "Error:", "'command' object is not a dictionary" + + if "name" not in command: + return "Error:", "Missing 'name' field in 'command' object" + + command_name = command["name"] + + # Use an empty dictionary if 'args' field is not present in 'command' object + arguments = command.get("args", {}) + + return command_name, arguments + except json.decoder.JSONDecodeError: + return "Error:", "Invalid JSON" + # All other errors, return "Error: + error message" + except Exception as e: + return "Error:", str(e) diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py new file mode 100644 index 000000000..22ebace5a --- /dev/null +++ b/pilot/base_modules/agent/commands/command_mange.py @@ -0,0 +1,156 @@ +import functools +import importlib +import inspect +from typing import Any, Callable, Optional + +# Unique identifier for auto-gpt commands +AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command" + + +class Command: + """A class representing a command. + + Attributes: + name (str): The name of the command. + description (str): A brief description of what the command does. + signature (str): The signature of the function that the command executes. Defaults to None. + """ + + def __init__( + self, + name: str, + description: str, + method: Callable[..., Any], + signature: str = "", + enabled: bool = True, + disabled_reason: Optional[str] = None, + ): + self.name = name + self.description = description + self.method = method + self.signature = signature if signature else str(inspect.signature(self.method)) + self.enabled = enabled + self.disabled_reason = disabled_reason + + def __call__(self, *args, **kwargs) -> Any: + if not self.enabled: + return f"Command '{self.name}' is disabled: {self.disabled_reason}" + return self.method(*args, **kwargs) + + def __str__(self) -> str: + return f"{self.name}: {self.description}, args: {self.signature}" + + +class CommandRegistry: + """ + The CommandRegistry class is a manager for a collection of Command objects. + It allows the registration, modification, and retrieval of Command objects, + as well as the scanning and loading of command plugins from a specified + directory. + """ + + def __init__(self): + self.commands = {} + + def _import_module(self, module_name: str) -> Any: + return importlib.import_module(module_name) + + def _reload_module(self, module: Any) -> Any: + return importlib.reload(module) + + def register(self, cmd: Command) -> None: + self.commands[cmd.name] = cmd + + def unregister(self, command_name: str): + if command_name in self.commands: + del self.commands[command_name] + else: + raise KeyError(f"Command '{command_name}' not found in registry.") + + def reload_commands(self) -> None: + """Reloads all loaded command plugins.""" + for cmd_name in self.commands: + cmd = self.commands[cmd_name] + module = self._import_module(cmd.__module__) + reloaded_module = self._reload_module(module) + if hasattr(reloaded_module, "register"): + reloaded_module.register(self) + + def get_command(self, name: str) -> Callable[..., Any]: + return self.commands[name] + + def call(self, command_name: str, **kwargs) -> Any: + if command_name not in self.commands: + raise KeyError(f"Command '{command_name}' not found in registry.") + command = self.commands[command_name] + return command(**kwargs) + + def command_prompt(self) -> str: + """ + Returns a string representation of all registered `Command` objects for use in a prompt + """ + commands_list = [ + f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values()) + ] + return "\n".join(commands_list) + + def import_commands(self, module_name: str) -> None: + """ + Imports the specified Python module containing command plugins. + + This method imports the associated module and registers any functions or + classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute + as `Command` objects. The registered `Command` objects are then added to the + `commands` dictionary of the `CommandRegistry` object. + + Args: + module_name (str): The name of the module to import for command plugins. + """ + + module = importlib.import_module(module_name) + + for attr_name in dir(module): + attr = getattr(module, attr_name) + # Register decorated functions + if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr( + attr, AUTO_GPT_COMMAND_IDENTIFIER + ): + self.register(attr.command) + # Register command classes + elif ( + inspect.isclass(attr) and issubclass(attr, Command) and attr != Command + ): + cmd_instance = attr() + self.register(cmd_instance) + + +def command( + name: str, + description: str, + signature: str = "", + enabled: bool = True, + disabled_reason: Optional[str] = None, +) -> Callable[..., Any]: + """The command decorator is used to create Command objects from ordinary functions.""" + + def decorator(func: Callable[..., Any]) -> Command: + cmd = Command( + name=name, + description=description, + method=func, + signature=signature, + enabled=enabled, + disabled_reason=disabled_reason, + ) + + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Any: + return func(*args, **kwargs) + + wrapper.command = cmd + + setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True) + + return wrapper + + return decorator diff --git a/pilot/base_modules/agent/commands/disply_type/__init__.py b/pilot/base_modules/agent/commands/disply_type/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py new file mode 100644 index 000000000..941586328 --- /dev/null +++ b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py @@ -0,0 +1,296 @@ +from pandas import DataFrame + +from pilot.base_modules.agent.commands.command_mange import command +from pilot.configs.config import Config +import pandas as pd +import uuid +import os +import matplotlib +import seaborn as sns + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.ticker as mtick +from matplotlib.font_manager import FontManager + +from pilot.configs.model_config import LOGDIR +from pilot.utils import build_logger + +CFG = Config() +logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") +static_message_img_path = os.path.join(os.getcwd(), "message/img") + + +def data_pre_classification(df: DataFrame): + ## Data pre-classification + columns = df.columns.tolist() + + number_columns = [] + non_numeric_colums = [] + + # 收集数据分类小于10个的列 + non_numeric_colums_value_map = {} + numeric_colums_value_map = {} + for column_name in columns: + if pd.api.types.is_numeric_dtype(df[column_name].dtypes): + number_columns.append(column_name) + unique_values = df[column_name].unique() + numeric_colums_value_map.update({column_name: len(unique_values)}) + else: + non_numeric_colums.append(column_name) + unique_values = df[column_name].unique() + non_numeric_colums_value_map.update({column_name: len(unique_values)}) + + sorted_numeric_colums_value_map = dict( + sorted(numeric_colums_value_map.items(), key=lambda x: x[1]) + ) + numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys()) + + sorted_colums_value_map = dict( + sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]) + ) + non_numeric_colums_sort_list = list(sorted_colums_value_map.keys()) + + # Analyze x-coordinate + if len(non_numeric_colums_sort_list) > 0: + x_cloumn = non_numeric_colums_sort_list[-1] + non_numeric_colums_sort_list.remove(x_cloumn) + else: + x_cloumn = number_columns[0] + numeric_colums_sort_list.remove(x_cloumn) + + # Analyze y-coordinate + if len(numeric_colums_sort_list) > 0: + y_column = numeric_colums_sort_list[0] + numeric_colums_sort_list.remove(y_column) + else: + raise ValueError("Not enough numeric columns for chart!") + + return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list + + +def zh_font_set(): + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] + fm = FontManager() + mat_fonts = set(f.name for f in fm.ttflist) + can_use_fonts = [] + for font_name in font_names: + if font_name in mat_fonts: + can_use_fonts.append(font_name) + if len(can_use_fonts) > 0: + plt.rcParams["font.sans-serif"] = can_use_fonts + + +@command( + "response_line_chart", + "Line chart display, used to display comparative trend analysis data", + '"speak": "", "df":""', +) +def response_line_chart(speak: str, df: DataFrame) -> str: + logger.info(f"response_line_chart:{speak},") + if df.size <= 0: + raise ValueError("No Data!") + + # set font + # zh_font_set() + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] + fm = FontManager() + mat_fonts = set(f.name for f in fm.ttflist) + can_use_fonts = [] + for font_name in font_names: + if font_name in mat_fonts: + can_use_fonts.append(font_name) + if len(can_use_fonts) > 0: + plt.rcParams["font.sans-serif"] = can_use_fonts + + rc = {"font.sans-serif": can_use_fonts} + plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 + + sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 + sns.set_palette("Set3") # 设置颜色主题 + sns.set_style("dark") + sns.color_palette("hls", 10) + sns.hls_palette(8, l=0.5, s=0.7) + sns.set(context="notebook", style="ticks", rc=rc) + + fig, ax = plt.subplots(figsize=(8, 5), dpi=100) + x, y, non_num_columns, num_colmns = data_pre_classification(df) + # ## 复杂折线图实现 + if len(num_colmns) > 0: + num_colmns.append(y) + df_melted = pd.melt( + df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" + ) + sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") + else: + sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") + + ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y))) + ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x))) + + chart_name = "line_" + str(uuid.uuid1()) + ".png" + chart_path = static_message_img_path + "/" + chart_name + plt.savefig(chart_path, bbox_inches="tight", dpi=100) + + html_img = f"""
{speak}
""" + return html_img + + +@command( + "response_bar_chart", + "Histogram, suitable for comparative analysis of multiple target values", + '"speak": "", "df":""', +) +def response_bar_chart(speak: str, df: DataFrame) -> str: + logger.info(f"response_bar_chart:{speak},") + if df.size <= 0: + raise ValueError("No Data!") + + # set font + # zh_font_set() + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] + fm = FontManager() + mat_fonts = set(f.name for f in fm.ttflist) + can_use_fonts = [] + for font_name in font_names: + if font_name in mat_fonts: + can_use_fonts.append(font_name) + if len(can_use_fonts) > 0: + plt.rcParams["font.sans-serif"] = can_use_fonts + + rc = {"font.sans-serif": can_use_fonts} + plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 + sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 + sns.set_palette("Set3") # 设置颜色主题 + sns.set_style("dark") + sns.color_palette("hls", 10) + sns.hls_palette(8, l=0.5, s=0.7) + sns.set(context="notebook", style="ticks", rc=rc) + + fig, ax = plt.subplots(figsize=(8, 5), dpi=100) + + hue = None + x, y, non_num_columns, num_colmns = data_pre_classification(df) + if len(non_num_columns) >= 1: + hue = non_num_columns[0] + + if len(num_colmns) >= 1: + if hue: + if len(num_colmns) >= 2: + can_use_columns = num_colmns[:2] + else: + can_use_columns = num_colmns + sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + for sub_y_column in can_use_columns: + sns.barplot( + data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax + ) + else: + if len(num_colmns) > 5: + can_use_columns = num_colmns[:5] + else: + can_use_columns = num_colmns + can_use_columns.append(y) + + df_melted = pd.melt( + df, + id_vars=x, + value_vars=can_use_columns, + var_name="line", + value_name="Value", + ) + sns.barplot( + data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax + ) + else: + sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + + # 设置 y 轴刻度格式为普通数字格式 + ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y))) + ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x))) + + chart_name = "bar_" + str(uuid.uuid1()) + ".png" + chart_path = static_message_img_path + "/" + chart_name + plt.savefig(chart_path, bbox_inches="tight", dpi=100) + html_img = f"""
{speak}
""" + return html_img + + +@command( + "response_pie_chart", + "Pie chart, suitable for scenarios such as proportion and distribution statistics", + '"speak": "", "df":""', +) +def response_pie_chart(speak: str, df: DataFrame) -> str: + logger.info(f"response_pie_chart:{speak},") + columns = df.columns.tolist() + if df.size <= 0: + raise ValueError("No Data!") + # set font + # zh_font_set() + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] + fm = FontManager() + mat_fonts = set(f.name for f in fm.ttflist) + can_use_fonts = [] + for font_name in font_names: + if font_name in mat_fonts: + can_use_fonts.append(font_name) + if len(can_use_fonts) > 0: + plt.rcParams["font.sans-serif"] = can_use_fonts + plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 + + sns.set_palette("Set3") # 设置颜色主题 + + # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) + fig, ax = plt.subplots(figsize=(8, 5), dpi=100) + ax = df.plot( + kind="pie", + y=columns[1], + ax=ax, + labels=df[columns[0]].values, + startangle=90, + autopct="%1.1f%%", + ) + + plt.axis("equal") # 使饼图为正圆形 + # plt.title(columns[0]) + + chart_name = "pie_" + str(uuid.uuid1()) + ".png" + chart_path = static_message_img_path + "/" + chart_name + plt.savefig(chart_path, bbox_inches="tight", dpi=100) + + html_img = f"""
{speak.replace("`", '"')}
""" + + return html_img diff --git a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py new file mode 100644 index 000000000..45f9f2f21 --- /dev/null +++ b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py @@ -0,0 +1,24 @@ +from pandas import DataFrame + +from pilot.base_modules.agent.commands.command_mange import command +from pilot.configs.config import Config + +from pilot.configs.model_config import LOGDIR +from pilot.utils import build_logger + +CFG = Config() +logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") + + +@command( + "response_table", + "Table display, suitable for display with many display columns or non-numeric columns", + '"speak": "", "df":""', +) +def response_table(speak: str, df: DataFrame) -> str: + logger.info(f"response_table:{speak}") + html_table = df.to_html(index=False, escape=False, sparsify=False) + table_str = "".join(html_table.split()) + html = f"""
{table_str}
""" + view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") + return view_text diff --git a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py new file mode 100644 index 000000000..f90ef087e --- /dev/null +++ b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py @@ -0,0 +1,39 @@ +from pandas import DataFrame + +from pilot.base_modules.agent.commands.command_mange import command +from pilot.configs.config import Config +from pilot.configs.model_config import LOGDIR +from pilot.utils import build_logger + +CFG = Config() +logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") + + +@command( + "response_data_text", + "Text display, the default display method, suitable for single-line or simple content display", + '"speak": "", "df":""', +) +def response_data_text(speak: str, df: DataFrame) -> str: + logger.info(f"response_data_text:{speak}") + data = df.values + + row_size = data.shape[0] + value_str = "" + text_info = "" + if row_size > 1: + html_table = df.to_html(index=False, escape=False, sparsify=False) + table_str = "".join(html_table.split()) + html = f"""
{table_str}
""" + text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") + elif row_size == 1: + row = data[0] + for value in row: + if value_str: + value_str = value_str + f", ** {value} **" + else: + value_str = f" ** {value} **" + text_info = f"{speak}: {value_str}" + else: + text_info = f"##### {speak}: _没有找到可用的数据_" + return text_info diff --git a/pilot/base_modules/agent/commands/exception_not_commands.py b/pilot/base_modules/agent/commands/exception_not_commands.py new file mode 100644 index 000000000..7d92f05c0 --- /dev/null +++ b/pilot/base_modules/agent/commands/exception_not_commands.py @@ -0,0 +1,4 @@ +class NotCommands(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message diff --git a/pilot/base_modules/agent/commands/generator.py b/pilot/base_modules/agent/commands/generator.py new file mode 100644 index 000000000..6752cd1e1 --- /dev/null +++ b/pilot/base_modules/agent/commands/generator.py @@ -0,0 +1,158 @@ +""" A module for generating custom prompt strings.""" +import json +from typing import Any, Callable, Dict, List, Optional + + +class PluginPromptGenerator: + """ + A class for generating custom prompt strings based on constraints, commands, + resources, and performance evaluations. + """ + + def __init__(self) -> None: + """ + Initialize the PromptGenerator object with empty lists of constraints, + commands, resources, and performance evaluations. + """ + self.constraints = [] + self.commands = [] + self.resources = [] + self.performance_evaluation = [] + self.goals = [] + self.command_registry = None + self.name = "Bob" + self.role = "AI" + self.response_format = { + "thoughts": { + "text": "thought", + "reasoning": "reasoning", + "plan": "- short bulleted\n- list that conveys\n- long-term plan", + "criticism": "constructive self-criticism", + "speak": "thoughts summary to say to user", + }, + "command": {"name": "command name", "args": {"arg name": "value"}}, + } + + def add_constraint(self, constraint: str) -> None: + """ + Add a constraint to the constraints list. + + Args: + constraint (str): The constraint to be added. + """ + self.constraints.append(constraint) + + def add_command( + self, + command_label: str, + command_name: str, + args=None, + function: Optional[Callable] = None, + ) -> None: + """ + Add a command to the commands list with a label, name, and optional arguments. + + Args: + command_label (str): The label of the command. + command_name (str): The name of the command. + args (dict, optional): A dictionary containing argument names and their + values. Defaults to None. + function (callable, optional): A callable function to be called when + the command is executed. Defaults to None. + """ + if args is None: + args = {} + + command_args = {arg_key: arg_value for arg_key, arg_value in args.items()} + + command = { + "label": command_label, + "name": command_name, + "args": command_args, + "function": function, + } + + self.commands.append(command) + + def _generate_command_string(self, command: Dict[str, Any]) -> str: + """ + Generate a formatted string representation of a command. + + Args: + command (dict): A dictionary containing command information. + + Returns: + str: The formatted command string. + """ + args_string = ", ".join( + f'"{key}": "{value}"' for key, value in command["args"].items() + ) + return f'{command["label"]}: "{command["name"]}", args: {args_string}' + + def add_resource(self, resource: str) -> None: + """ + Add a resource to the resources list. + + Args: + resource (str): The resource to be added. + """ + self.resources.append(resource) + + def add_performance_evaluation(self, evaluation: str) -> None: + """ + Add a performance evaluation item to the performance_evaluation list. + + Args: + evaluation (str): The evaluation item to be added. + """ + self.performance_evaluation.append(evaluation) + + def _generate_numbered_list(self, items: List[Any], item_type="list") -> str: + """ + Generate a numbered list from given items based on the item_type. + + Args: + items (list): A list of items to be numbered. + item_type (str, optional): The type of items in the list. + Defaults to 'list'. + + Returns: + str: The formatted numbered list. + """ + if item_type == "command": + command_strings = [] + if self.command_registry: + command_strings += [ + str(item) + for item in self.command_registry.commands.values() + if item.enabled + ] + # terminate command is added manually + command_strings += [self._generate_command_string(item) for item in items] + return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) + else: + return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items)) + + def generate_commands_string(self) -> str: + return f"{self._generate_numbered_list(self.commands, item_type='command')}" + + def generate_prompt_string(self) -> str: + """ + Generate a prompt string based on the constraints, commands, resources, + and performance evaluations. + + Returns: + str: The generated prompt string. + """ + formatted_response_format = json.dumps(self.response_format, indent=4) + return ( + f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n" + "Commands:\n" + f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n" + f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n" + "Performance Evaluation:\n" + f"{self._generate_numbered_list(self.performance_evaluation)}\n\n" + "You should only respond in JSON format as described below and ensure the" + "response can be parsed by Python json.loads \nResponse" + f" Format: \n{formatted_response_format}" + ) diff --git a/pilot/base_modules/agent/commands/times.py b/pilot/base_modules/agent/commands/times.py new file mode 100644 index 000000000..3c9b8a4fc --- /dev/null +++ b/pilot/base_modules/agent/commands/times.py @@ -0,0 +1,10 @@ +from datetime import datetime + + +def get_datetime() -> str: + """Return the current date and time + + Returns: + str: The current date and time + """ + return "Current date and time: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py new file mode 100644 index 000000000..cbf0e3b7b --- /dev/null +++ b/pilot/base_modules/agent/controller.py @@ -0,0 +1,26 @@ +import json +import time +from fastapi import ( + APIRouter, + Body, +) + +from typing import List +from pilot.configs.model_config import LOGDIR +from pilot.utils import build_logger + +from pilot.openapi.api_view_model import ( + Result, +) + + + +router = APIRouter() +logger = build_logger("agent_mange", LOGDIR + "agent_mange.log") + + +@router.get("/v1/mange/agent/list", response_model=Result[str]) +async def get_agent_list(): + logger.info(f"get_agent_list!") + + return Result.succ(None) diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py new file mode 100644 index 000000000..28a50d0dc --- /dev/null +++ b/pilot/base_modules/agent/db/my_plugin_db.py @@ -0,0 +1,136 @@ +from datetime import datetime +from typing import List +from sqlalchemy import Column, Integer, String, Index, DateTime, func +from sqlalchemy import UniqueConstraint + +from pilot.base_modules.meta_data.meta_data import Base + +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session + + +class MyPluginEntity(Base): + __tablename__ = 'my_plugin' + + id = Column(Integer, primary_key=True, comment="autoincrement id") + tenant = Column(String, nullable=True, comment="user's tenant") + user_code = Column(String, nullable=True, comment="user code") + user_name = Column(String, nullable=True, comment="user name") + name = Column(String, unique=True, nullable=False, comment="plugin name") + type = Column(String, comment="plugin type") + version = Column(String, comment="plugin version") + use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") + succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count") + created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time") + __table_args__ = ( + UniqueConstraint('name', name="uk_name"), + ) + + +class MyPluginDao(BaseDao[MyPluginEntity]): + def __init__(self): + super().__init__( + database="dbgpt", orm_base=Base, db_engine =engine , session= session + ) + + def add(self, engity: MyPluginEntity): + session = self.Session() + my_plugin = MyPluginEntity( + tenant=engity.tenant, + user_code=engity.user_code, + user_name=engity.user_name, + name=engity.name, + type=engity.type, + version=engity.version, + use_count=engity.use_count or 0, + succ_count=engity.succ_count or 0, + created_at=datetime.now(), + ) + session.add(my_plugin) + session.commit() + id = my_plugin.id + session.close() + return id + + def update(self, entity: MyPluginEntity): + session = self.Session() + updated = session.merge(entity) + session.commit() + return updated.id + + + def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]: + session = self.Session() + my_plugins = session.query(MyPluginEntity) + if query.id is not None: + my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) + if query.name is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.name == query.name + ) + if query.tenant is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.tenant == query.tenant + ) + if query.type is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.type == query.type + ) + if query.user_code is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.user_code == query.user_code + ) + if query.user_name is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.user_name == query.user_name + ) + + my_plugins = my_plugins.order_by(MyPluginEntity.id.desc()) + my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size) + result = my_plugins.all() + session.close() + return result + + def count(self, query: MyPluginEntity): + session = self.Session() + my_plugins = session.query(func.count(MyPluginEntity.id)) + if query.id is not None: + my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) + if query.name is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.name == query.name + ) + if query.type is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.type == query.type + ) + if query.tenant is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.tenant == query.tenant + ) + if query.user_code is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.user_code == query.user_code + ) + if query.user_name is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.user_name == query.user_name + ) + count = my_plugins.scalar() + session.close() + return count + + + def delete(self, plugin_id: int): + session = self.Session() + if plugin_id is None: + raise Exception("plugin_id is None") + query = MyPluginEntity(id=plugin_id) + my_plugins = session.query(MyPluginEntity) + if query.id is not None: + my_plugins = my_plugins.filter( + MyPluginEntity.id == query.id + ) + my_plugins.delete() + session.commit() + session.close() diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py new file mode 100644 index 000000000..1dfb2363d --- /dev/null +++ b/pilot/base_modules/agent/db/plugin_hub_db.py @@ -0,0 +1,136 @@ +from datetime import datetime +from typing import List +from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean +from sqlalchemy import UniqueConstraint + +from pilot.base_modules.meta_data.meta_data import Base + +from pilot.base_modules.meta_data.base_dao import BaseDao +from pilot.base_modules.meta_data.meta_data import Base, engine, session + + +class PluginHubEntity(Base): + __tablename__ = 'plugin_hub' + id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") + name = Column(String, unique=True, nullable=False, comment="plugin name") + author = Column(String, nullable=True, comment="plugin author") + email = Column(String, nullable=True, comment="plugin author email") + type = Column(String, comment="plugin type") + version = Column(String, comment="plugin version") + storage_channel = Column(String, comment="plugin storage channel") + storage_url = Column(String, comment="plugin download url") + created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time") + installed = Column(Boolean, default=False, comment="plugin already installed") + + __table_args__ = ( + UniqueConstraint('name', name="uk_name"), + Index('idx_q_type', 'type'), + ) + + +class PluginHubDao(BaseDao[PluginHubEntity]): + def __init__(self): + super().__init__( + database="dbgpt", orm_base=Base, db_engine=engine, session=session + ) + + def add(self, engity: PluginHubEntity): + session = self.Session() + plugin_hub = PluginHubEntity( + name=engity.name, + author=engity.author, + email=engity.email, + type=engity.type, + version=engity.version, + storage_channel=engity.storage_channel, + storage_url=engity.storage_url, + created_at=datetime.now(), + ) + session.add(plugin_hub) + session.commit() + id = plugin_hub.id + session.close() + return id + + def update(self, entity: PluginHubEntity): + session = self.Session() + updated = session.merge(entity) + session.commit() + return updated.id + + def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]: + session = self.Session() + plugin_hubs = session.query(PluginHubEntity) + if query.id is not None: + plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) + if query.name is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.name == query.name + ) + if query.type is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.type == query.type + ) + if query.author is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.author == query.author + ) + if query.storage_channel is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.storage_channel == query.storage_channel + ) + + plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc()) + plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size) + result = plugin_hubs.all() + session.close() + return result + + def get_by_name(self, name: str) -> PluginHubEntity: + session = self.Session() + plugin_hubs = session.query(PluginHubEntity) + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.name == name + ) + result = plugin_hubs.get(1) + session.close() + return result + + def count(self, query: PluginHubEntity): + session = self.Session() + plugin_hubs = session.query(func.count(PluginHubEntity.id)) + if query.id is not None: + plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) + if query.name is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.name == query.name + ) + if query.type is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.type == query.type + ) + if query.author is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.author == query.author + ) + if query.storage_channel is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.storage_channel == query.storage_channel + ) + count = plugin_hubs.scalar() + session.close() + return count + + def delete(self, plugin_id: int): + session = self.Session() + if plugin_id is None: + raise Exception("plugin_id is None") + query = PluginHubEntity(id=plugin_id) + plugin_hubs = session.query(PluginHubEntity) + if query.id is not None: + plugin_hubs = plugin_hubs.filter( + PluginHubEntity.id == query.id + ) + plugin_hubs.delete() + session.commit() + session.close() diff --git a/pilot/base_modules/agent/hub/__init__.py b/pilot/base_modules/agent/hub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/agent/hub/agent_hub.py b/pilot/base_modules/agent/hub/agent_hub.py new file mode 100644 index 000000000..54e76dcca --- /dev/null +++ b/pilot/base_modules/agent/hub/agent_hub.py @@ -0,0 +1,69 @@ +import logging + +from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao +from ..db.my_plugin_db import MyPluginDao, MyPluginEntity +from .schema import PluginStorageType + +logger = logging.getLogger("agent_hub") + + +class AgentHub: + def __init__(self) -> None: + self.hub_dao = PluginHubDao() + self.my_lugin_dao = MyPluginDao() + + def install_plugin(self, plugin_name: str, user_name: str = None): + logger.info(f"install_plugin {plugin_name}") + + plugin_entity = self.hub_dao.get_by_name(plugin_name) + if plugin_entity: + if plugin_entity.storage_channel == PluginStorageType.Git.value: + try: + self.__download_from_git(plugin_name, plugin_entity.storage_url) + self.load_plugin(plugin_name) + + # add to my plugins and edit hub status + plugin_entity.installed = True + + my_plugin_entity = self.__build_my_plugin(plugin_entity) + if not user_name: + # TODO use user + my_plugin_entity.user_code = "" + my_plugin_entity.user_name = user_name + my_plugin_entity.tenant = "" + + with self.hub_dao.Session() as session: + try: + session.add(my_plugin_entity) + session.merge(plugin_entity) + session.commit() + except: + session.rollback() + except Exception as e: + logger.error("install pluguin exception!", e) + raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}") + + else: + raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!") + else: + raise ValueError(f"Can't Find Plugin {plugin_name}!") + + def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity: + my_plugin_entity = MyPluginEntity() + my_plugin_entity.name = hub_plugin.name + my_plugin_entity.type = hub_plugin.type + my_plugin_entity.version = hub_plugin.version + return my_plugin_entity + + def __download_from_git(self, plugin_name, url): + pass + + def load_plugin(self, plugin_name): + + pass + + def get_my_plugin(self, user: str): + pass + + def uninstall_plugin(self): + pass diff --git a/pilot/base_modules/agent/hub/schema.py b/pilot/base_modules/agent/hub/schema.py new file mode 100644 index 000000000..225714b86 --- /dev/null +++ b/pilot/base_modules/agent/hub/schema.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class PluginStorageType(Enum): + Git = "git" + Oss = "oss" \ No newline at end of file diff --git a/pilot/base_modules/agent/plugins.py b/pilot/base_modules/agent/plugins.py new file mode 100644 index 000000000..3ee5f4ac2 --- /dev/null +++ b/pilot/base_modules/agent/plugins.py @@ -0,0 +1,182 @@ +"""加载组件""" + +import json +import os +import glob +import zipfile +import requests +import threading +import datetime +from pathlib import Path +from typing import List +from urllib.parse import urlparse +from zipimport import zipimporter + +import requests +from auto_gpt_plugin_template import AutoGPTPluginTemplate + +from pilot.configs.config import Config +from pilot.configs.model_config import PLUGINS_DIR +from pilot.logs import logger + + +def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: + """ + Loader zip plugin file. Native support Auto_gpt_plugin + + Args: + zip_path (str): Path to the zipfile. + debug (bool, optional): Enable debug logging. Defaults to False. + + Returns: + list[str]: The list of module names found or empty list if none were found. + """ + result = [] + with zipfile.ZipFile(zip_path, "r") as zfile: + for name in zfile.namelist(): + if name.endswith("__init__.py") and not name.startswith("__MACOSX"): + logger.debug(f"Found module '{name}' in the zipfile at: {name}") + result.append(name) + if len(result) == 0: + logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.") + return result + + +def write_dict_to_json_file(data: dict, file_path: str) -> None: + """ + Write a dictionary to a JSON file. + Args: + data (dict): Dictionary to write. + file_path (str): Path to the file. + """ + with open(file_path, "w") as file: + json.dump(data, file, indent=4) + + +def create_directory_if_not_exists(directory_path: str) -> bool: + """ + Create a directory if it does not exist. + Args: + directory_path (str): Path to the directory. + Returns: + bool: True if the directory was created, else False. + """ + if not os.path.exists(directory_path): + try: + os.makedirs(directory_path) + logger.debug(f"Created directory: {directory_path}") + return True + except OSError as e: + logger.warn(f"Error creating directory {directory_path}: {e}") + return False + else: + logger.info(f"Directory {directory_path} already exists") + return True + + +def load_native_plugins(cfg: Config): + if not cfg.plugins_auto_load: + print("not auto load_native_plugins") + return + + def load_from_git(cfg: Config): + print("async load_native_plugins") + branch_name = cfg.plugins_git_branch + native_plugin_repo = "DB-GPT-Plugins" + url = "https://github.com/csunny/{repo}/archive/{branch}.zip" + try: + session = requests.Session() + response = session.get( + url.format(repo=native_plugin_repo, branch=branch_name), + headers={"Authorization": "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"}, + ) + + if response.status_code == 200: + plugins_path_path = Path(PLUGINS_DIR) + files = glob.glob( + os.path.join(plugins_path_path, f"{native_plugin_repo}*") + ) + for file in files: + os.remove(file) + now = datetime.datetime.now() + time_str = now.strftime("%Y%m%d%H%M%S") + file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" + print(file_name) + with open(file_name, "wb") as f: + f.write(response.content) + print("save file") + cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) + else: + print("get file faild,response code:", response.status_code) + except Exception as e: + print("load plugin from git exception!" + str(e)) + + t = threading.Thread(target=load_from_git, args=(cfg,)) + t.start() + + +def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: + """Scan the plugins directory for plugins and loads them. + + Args: + cfg (Config): Config instance including plugins config + debug (bool, optional): Enable debug logging. Defaults to False. + + Returns: + List[Tuple[str, Path]]: List of plugins. + """ + loaded_plugins = [] + current_dir = os.getcwd() + print(current_dir) + # Generic plugins + plugins_path_path = Path(PLUGINS_DIR) + + for plugin in plugins_path_path.glob("*.zip"): + if moduleList := inspect_zip_for_modules(str(plugin), debug): + for module in moduleList: + plugin = Path(plugin) + module = Path(module) + logger.debug(f"Plugin: {plugin} Module: {module}") + zipped_package = zipimporter(str(plugin)) + zipped_module = zipped_package.load_module(str(module.parent)) + for key in dir(zipped_module): + if key.startswith("__"): + continue + a_module = getattr(zipped_module, key) + a_keys = dir(a_module) + if ( + "_abc_impl" in a_keys + and a_module.__name__ != "AutoGPTPluginTemplate" + # and denylist_allowlist_check(a_module.__name__, cfg) + ): + loaded_plugins.append(a_module()) + + if loaded_plugins: + logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") + for plugin in loaded_plugins: + logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}") + return loaded_plugins + + +def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool: + """Check if the plugin is in the allowlist or denylist. + + Args: + plugin_name (str): Name of the plugin. + cfg (Config): Config object. + + Returns: + True or False + """ + logger.debug(f"Checking if plugin {plugin_name} should be loaded") + if plugin_name in cfg.plugins_denylist: + logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.") + return False + if plugin_name in cfg.plugins_allowlist: + logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.") + return True + ack = input( + f"WARNING: Plugin {plugin_name} found. But not in the" + f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): " + ) + return ack.lower() == cfg.authorise_key diff --git a/pilot/base_modules/agent/requirement.txt b/pilot/base_modules/agent/requirement.txt new file mode 100644 index 000000000..c4a6c89af --- /dev/null +++ b/pilot/base_modules/agent/requirement.txt @@ -0,0 +1,2 @@ +flask_sqlalchemy==3.0.5 +flask==2.3.2 \ No newline at end of file diff --git a/pilot/base_modules/base.py b/pilot/base_modules/base.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/mange_base_api.py b/pilot/base_modules/mange_base_api.py new file mode 100644 index 000000000..c0b5da273 --- /dev/null +++ b/pilot/base_modules/mange_base_api.py @@ -0,0 +1,7 @@ +class ModuleMangeApi: + + def module_name(self): + pass + + def register(self): + pass \ No newline at end of file diff --git a/pilot/base_modules/meta_data/__init__.py b/pilot/base_modules/meta_data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/meta_data/base_dao.py b/pilot/base_modules/meta_data/base_dao.py new file mode 100644 index 000000000..733b252aa --- /dev/null +++ b/pilot/base_modules/meta_data/base_dao.py @@ -0,0 +1,21 @@ +from typing import TypeVar, Generic, List, Any +from sqlalchemy.orm import sessionmaker + +T = TypeVar('T') + +class BaseDao(Generic[T]): + def __init__( + self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None, + ) -> None: + """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist""" + self._orm_base = orm_base + self._database = database + + self._db_engine = db_engine + self._session = session + + @property + def Session(self): + if not self._session: + self._session = sessionmaker(bind=self.db_engine) + return self._session diff --git a/pilot/base_modules/meta_data/meta_data.py b/pilot/base_modules/meta_data/meta_data.py new file mode 100644 index 000000000..74edf13a5 --- /dev/null +++ b/pilot/base_modules/meta_data/meta_data.py @@ -0,0 +1,109 @@ +import uuid +import os +import duckdb +import sqlite3 +from datetime import datetime +from typing import Optional, Type, TypeVar + +import sqlalchemy as sa + +from flask import Flask +from flask_sqlalchemy import SQLAlchemy +from flask_migrate import Migrate,upgrade +from flask.cli import with_appcontext +import subprocess + +from sqlalchemy import create_engine,DateTime, String, func, MetaData +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.declarative import declarative_base + +from alembic import context, command +from alembic.config import Config + +default_db_path = os.path.join(os.getcwd(), "meta_data") + +os.makedirs(default_db_path, exist_ok=True) + +db_path = default_db_path + "/dbgpt.db" +connection = sqlite3.connect(db_path) +engine = create_engine(f'sqlite:///{db_path}') + +Session = sessionmaker(autocommit=False, autoflush=False, bind=engine) +session = Session() + +Base = declarative_base(bind=engine) + +# Base.metadata.create_all() + +# 创建Alembic配置对象 + +alembic_ini_path = default_db_path + "/alembic.ini" +alembic_cfg = Config(alembic_ini_path) + +alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url)) + +os.makedirs(default_db_path + "/alembic", exist_ok=True) +alembic_cfg.set_main_option('script_location', default_db_path + "/alembic") + +# 将模型和会话传递给Alembic配置 +alembic_cfg.attributes['target_metadata'] = Base.metadata +alembic_cfg.attributes['session'] = session + + +# # 创建表 +# Base.metadata.create_all(engine) +# +# # 删除表 +# Base.metadata.drop_all(engine) + + +# app = Flask(__name__) +# default_db_path = os.path.join(os.getcwd(), "meta_data") +# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db") +# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}' +# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False +# db = SQLAlchemy(app) +# migrate = Migrate(app, db) +# +# # 设置FLASK_APP环境变量 +# import os +# os.environ['FLASK_APP'] = 'server.dbgpt_server.py' +# +# @app.cli.command("db_init") +# @with_appcontext +# def db_init(): +# subprocess.run(["flask", "db", "init"]) +# +# @app.cli.command("db_migrate") +# @with_appcontext +# def db_migrate(): +# subprocess.run(["flask", "db", "migrate"]) +# +# @app.cli.command("db_upgrade") +# @with_appcontext +# def db_upgrade(): +# subprocess.run(["flask", "db", "upgrade"]) +# + + + +def ddl_init_and_upgrade(): + # Base.metadata.create_all(bind=engine) + # 生成并应用迁移脚本 + # command.upgrade(alembic_cfg, 'head') + # subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"]) + with engine.connect() as connection: + alembic_cfg.attributes['connection'] = connection + command.revision(alembic_cfg, "test", True) + command.upgrade(alembic_cfg, "head") + # alembic_cfg.attributes['connection'] = engine.connect() + # command.upgrade(alembic_cfg, 'head') + + # with app.app_context(): + # db_init() + # db_migrate() + # db_upgrade() + + diff --git a/pilot/base_modules/meta_data/requirement.txt b/pilot/base_modules/meta_data/requirement.txt new file mode 100644 index 000000000..de93d8a90 --- /dev/null +++ b/pilot/base_modules/meta_data/requirement.txt @@ -0,0 +1 @@ +alembic==1.12.0 \ No newline at end of file diff --git a/pilot/base_modules/module_factory.py b/pilot/base_modules/module_factory.py new file mode 100644 index 000000000..e69de29bb