mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
fix(ChatExcel): fix ChatExcel output-parsing bug
1. Fix the ChatExcel output-parser bug
This commit is contained in:
parent
cad2785d94
commit
af68e9c4c0
4
.gitignore
vendored
4
.gitignore
vendored
@ -149,4 +149,6 @@ pilot/mock_datas/db-gpt-test.db.wal
|
||||
|
||||
logswebserver.log.*
|
||||
.history/*
|
||||
.plugin_env
|
||||
.plugin_env
|
||||
/pilot/meta_data/alembic/versions/*
|
||||
/pilot/meta_data/*.db
|
@ -1,2 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
@ -1,12 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
|
||||
class Agent:
    """Base agent type for interacting with DB-GPT.

    Currently a stateless placeholder; attributes will be added as the
    agent implementation grows.
    """

    def __init__(self) -> None:
        """No state is initialized yet."""
        pass
|
@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
"""Agent manager for managing GPT agents"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.model.base import Message
|
||||
from pilot.singleton import Singleton
|
||||
|
||||
|
||||
class AgentManager(metaclass=Singleton):
    """Agent manager for managing DB-GPT agents.

    In order to stay compatible with Auto-GPT plugins, this uses the same
    template as Auto-GPT's agent manager.

    Attributes:
        next_key: Key to assign to the next created agent.
        agents: Mapping of key -> agent record (task, full_message_history, model).
        cfg: Global configuration singleton.
    """

    def __init__(self) -> None:
        # BUG FIX: the original class defined __init__ twice with equivalent
        # bodies (the second silently overrode the first) and carried a stray
        # docstring in the middle of the class body; both are consolidated here.
        self.next_key = 0
        self.agents = {}  # key -> (task, full_message_history, model)
        self.cfg = Config()

    # Create new GPT agent
    # TODO: Centralise use of create_chat_completion() to globally enforce token limit

    def create_agent(self, task: str, prompt: str, model: str) -> tuple[int, str]:
        """Create a new agent and return its key

        Args:
            task: The task to perform
            prompt: The prompt to use
            model: The model to use

        Returns:
            The key of the new agent
        """
        # NOTE(review): not implemented in the original source; implicitly
        # returns None despite the annotated return type — confirm intent.

    def message_agent(self, key: str | int, message: str) -> str:
        """Send a message to an agent and return its response

        Args:
            key: The key of the agent to message
            message: The message to send to the agent

        Returns:
            The agent's response
        """
        # NOTE(review): not implemented in the original source; implicitly
        # returns None despite the annotated return type — confirm intent.

    def list_agents(self) -> list[tuple[str | int, str]]:
        """Return a list of all agents

        Returns:
            A list of tuples of the form (key, task)
        """
        # Return a list of agent keys and their tasks
        return [(key, task) for key, (task, _, _) in self.agents.items()]

    def delete_agent(self, key: str | int) -> bool:
        """Delete an agent from the agent manager

        Args:
            key: The key of the agent to delete

        Returns:
            True if successful, False otherwise
        """
        try:
            del self.agents[int(key)]
            return True
        except KeyError:
            return False
|
@ -1,122 +0,0 @@
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
from colorama import Fore
|
||||
from regex import regex
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.json_utils.json_fix_general import (
|
||||
add_quotes_to_property_names,
|
||||
balance_braces,
|
||||
fix_invalid_escape,
|
||||
)
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def fix_and_parse_json(
    json_to_load: str, try_to_fix_with_gpt: bool = True
) -> Dict[Any, Any]:
    """Fix and parse JSON string

    Args:
        json_to_load (str): The JSON string.
        try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
            Defaults to True.

    Returns:
        str or dict[Any, Any]: The parsed JSON, or None when every
        recovery strategy fails.
    """
    # Strategy 1: strip tab characters and parse directly.
    json_to_load = json_to_load.replace("\t", "")
    try:
        return json.loads(json_to_load)
    except json.JSONDecodeError:
        pass

    # Strategy 2: run the common-error corrector, then parse again.
    try:
        json_to_load = correct_json(json_to_load)
        return json.loads(json_to_load)
    except json.JSONDecodeError:
        pass

    # Strategy 3: the model sometimes emits prose around the JSON object,
    # e.g.  "I'm sorry, ..." {"text": ..., "confidence": 0.0}.
    # Slice from the first "{" to the last "}" and parse only that span.
    try:
        first_brace = json_to_load.index("{")
        last_brace = json_to_load.rindex("}", first_brace)
        return json.loads(json_to_load[first_brace : last_brace + 1])
    except (json.JSONDecodeError, ValueError) as e:
        logger.error("参数解析错误", e)
|
||||
|
||||
|
||||
def correct_json(json_to_load: str) -> str:
    """
    Correct common JSON errors.
    Args:
        json_to_load (str): The JSON string.

    Returns the corrected string when a fix succeeds, otherwise the
    (possibly partially fixed) input string unchanged.
    """
    try:
        logger.debug("json", json_to_load)
        json.loads(json_to_load)
        # Already valid — nothing to fix.
        return json_to_load
    except json.JSONDecodeError as e:
        logger.debug("json loads error", e)
        error_message = str(e)
        # Dispatch a fixer based on the decoder's error-message prefix.
        if error_message.startswith("Invalid \\escape"):
            json_to_load = fix_invalid_escape(json_to_load, error_message)
        if error_message.startswith(
            "Expecting property name enclosed in double quotes"
        ):
            json_to_load = add_quotes_to_property_names(json_to_load)
            try:
                json.loads(json_to_load)
                return json_to_load
            except json.JSONDecodeError as e:
                logger.debug("json loads error - add quotes", e)
                error_message = str(e)
        # Last resort: append/trim braces so the string balances.
        if balanced_str := balance_braces(json_to_load):
            return balanced_str
    # Fall through: return the input even if it is still invalid.
    return json_to_load
|
||||
|
||||
|
||||
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
    """Salvage a JSON object embedded in a larger string.

    Uses a recursive regex to locate the outermost balanced ``{...}`` span
    in *json_string* and parses only that span.

    Args:
        json_string (str): Raw model output that should contain JSON.

    Returns:
        dict: The parsed JSON object, or {} when no JSON could be found.
    """
    from pilot.speech.say import say_text

    if CFG.speak_mode and CFG.debug_mode:
        say_text(
            "I have received an invalid JSON response from the OpenAI API. "
            "Trying to fix it now."
        )
    logger.error("Attempting to fix JSON by finding outermost brackets\n")

    try:
        # (?R) recursively matches nested brace groups, so this finds the
        # outermost balanced {...} span (requires the third-party `regex`).
        json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
        json_match = json_pattern.search(json_string)

        if json_match:
            # Extract the valid JSON object from the string
            json_string = json_match.group(0)
            logger.typewriter_log(
                title="Apparently json was fixed.", title_color=Fore.GREEN
            )
            if CFG.speak_mode and CFG.debug_mode:
                say_text("Apparently json was fixed.")
        else:
            return {}

    except (json.JSONDecodeError, ValueError):
        if CFG.debug_mode:
            logger.error(f"Error: Invalid JSON: {json_string}\n")
        if CFG.speak_mode:
            say_text("Didn't work. I will have to ignore this response then.")
        logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
        # BUG FIX: the original set json_string = {} and fell through to
        # fix_and_parse_json({}), which crashes calling str.replace on a
        # dict.  Return the empty object directly instead.
        return {}

    return fix_and_parse_json(json_string)
|
0
pilot/base_modules/__init__.py
Normal file
0
pilot/base_modules/__init__.py
Normal file
6
pilot/base_modules/agent/__init__.py
Normal file
6
pilot/base_modules/agent/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
|
||||
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
|
||||
from .commands.command import execute_command, get_command
|
||||
from .commands.generator import PluginPromptGenerator
|
||||
from .commands.disply_type.show_chart_gen import static_message_img_path
|
13
pilot/base_modules/agent/agent.py
Normal file
13
pilot/base_modules/agent/agent.py
Normal file
@ -0,0 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class AgentFacade(ABC):
    """Abstract entry point for agent implementations."""

    def __init__(self) -> None:
        # Concrete subclasses are expected to attach a model.
        self.model = None
|
||||
|
||||
|
||||
|
0
pilot/base_modules/agent/commands/__init__.py
Normal file
0
pilot/base_modules/agent/commands/__init__.py
Normal file
61
pilot/base_modules/agent/commands/built_in/audio_text.py
Normal file
61
pilot/base_modules/agent/commands/built_in/audio_text.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""Commands for converting audio to text."""
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@command(
    "read_audio_from_file",
    "Convert Audio to text",
    '"filename": "<filename>"',
    CFG.huggingface_audio_to_text_model,
    "Configure huggingface_audio_to_text_model.",
)
def read_audio_from_file(filename: str) -> str:
    """
    Convert audio to text.

    Args:
        filename (str): The path to the audio file

    Returns:
        str: The text from the audio
    """
    # Read the whole file as raw bytes, then delegate to read_audio().
    with open(filename, "rb") as audio_file:
        raw_audio = audio_file.read()
    return read_audio(raw_audio)
|
||||
|
||||
|
||||
def read_audio(audio: bytes) -> str:
    """
    Convert audio to text.

    Args:
        audio (bytes): The audio to convert

    Returns:
        str: The text from the audio

    Raises:
        ValueError: If no Hugging Face API token is configured.
    """
    # Fail fast when the token is missing — nothing else can proceed.
    api_token = CFG.huggingface_api_token
    if api_token is None:
        raise ValueError(
            "You need to set your Hugging Face API token in the config file."
        )

    model = CFG.huggingface_audio_to_text_model
    api_url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {api_token}"}

    response = requests.post(
        api_url,
        headers=headers,
        data=audio,
    )

    text = json.loads(response.content.decode("utf-8"))["text"]
    return f"The audio says: {text}"
|
123
pilot/base_modules/agent/commands/built_in/image_gen.py
Normal file
123
pilot/base_modules/agent/commands/built_in/image_gen.py
Normal file
@ -0,0 +1,123 @@
|
||||
""" Image Generation Module for AutoGPT."""
|
||||
import io
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
def generate_image(prompt: str, size: int = 256) -> str:
    """Generate an image from a prompt.

    Args:
        prompt (str): The prompt to use
        size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)

    Returns:
        str: The filename of the image
    """
    filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"

    provider = CFG.image_provider
    # Dispatch on the configured provider.
    if provider == "huggingface":
        # HuggingFace inference API
        return generate_image_with_hf(prompt, filename)
    if provider == "sdwebui":
        # Local Stable Diffusion WebUI
        return generate_image_with_sd_webui(prompt, filename, size)
    return "No Image Provider Set"
|
||||
|
||||
|
||||
def generate_image_with_hf(prompt: str, filename: str) -> str:
    """Generate an image with HuggingFace's API.

    Args:
        prompt (str): The prompt to use
        filename (str): The filename to save the image to

    Returns:
        str: A message containing the path the image was saved to.

    Raises:
        ValueError: If no Hugging Face API token is configured.
    """
    API_URL = (
        f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
    )
    if CFG.huggingface_api_token is None:
        raise ValueError(
            "You need to set your Hugging Face API token in the config file."
        )
    headers = {
        "Authorization": f"Bearer {CFG.huggingface_api_token}",
        # Force a fresh generation instead of a cached result.
        "X-Use-Cache": "false",
    }

    response = requests.post(
        API_URL,
        headers=headers,
        json={
            "inputs": prompt,
        },
    )

    image = Image.open(io.BytesIO(response.content))
    logger.info(f"Image Generated for prompt:{prompt}")

    image.save(filename)

    # BUG FIX: report the path that was actually written — the {filename}
    # placeholder had been lost from the return string.
    return f"Saved to disk:{filename}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
    prompt: str,
    filename: str,
    size: int = 512,
    negative_prompt: str = "",
    extra: dict = {},  # NOTE(review): mutable default; safe here (read-only), kept for interface compatibility
) -> str:
    """Generate an image with Stable Diffusion webui.
    Args:
        prompt (str): The prompt to use
        filename (str): The filename to save the image to
        size (int, optional): The size of the image. Defaults to 256.
        negative_prompt (str, optional): The negative prompt to use. Defaults to "".
        extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
    Returns:
        str: A message containing the path the image was saved to.
    """
    # Create a session and set the basic auth if needed
    s = requests.Session()
    if CFG.sd_webui_auth:
        username, password = CFG.sd_webui_auth.split(":")
        s.auth = (username, password or "")

    # Generate the images
    # BUG FIX: post through the session so the basic auth configured above
    # is actually sent (the original called requests.post, leaving the
    # session unused).
    response = s.post(
        f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
        json={
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "sampler_index": "DDIM",
            "steps": 20,
            "cfg_scale": 7.0,
            "width": size,
            "height": size,
            "n_iter": 1,
            **extra,
        },
    )

    logger.info(f"Image Generated for prompt:{prompt}")

    # Save the image to disk
    response = response.json()
    b64 = b64decode(response["images"][0].split(",", 1)[0])
    image = Image.open(io.BytesIO(b64))
    image.save(filename)

    # BUG FIX: report the path that was actually written — the {filename}
    # placeholder had been lost from the return string.
    return f"Saved to disk:{filename}"
|
152
pilot/base_modules/agent/commands/command.py
Normal file
152
pilot/base_modules/agent/commands/command.py
Normal file
@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from .exception_not_commands import NotCommands
|
||||
from .generator import PluginPromptGenerator
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||
# todo
|
||||
command_args["directory"] = ""
|
||||
else:
|
||||
for pathlike in ["filename", "directory", "clone_path"]:
|
||||
if pathlike in command_args:
|
||||
# todo
|
||||
command_args[pathlike] = ""
|
||||
return command_args
|
||||
|
||||
|
||||
def execute_ai_response_json(
    prompt: PluginPromptGenerator,
    ai_response,
    user_input: str = None,
) -> str:
    """Parse *ai_response*, run the command it names, and return the result.

    Args:
        prompt: Plugin prompt generator holding the command registry.
        ai_response: Raw model response containing a command payload.
        user_input: Optional human feedback, echoed back verbatim.

    Returns:
        str: Human-readable result of executing the command.
    """
    from pilot.speech.say import say_text

    cfg = Config()

    command_name, arguments = get_command(ai_response)

    if cfg.speak_mode:
        say_text(f"I want to execute {command_name}")

    arguments = _resolve_pathlike_command_args(arguments)

    # Execute command
    if command_name is not None and command_name.lower().startswith("error"):
        return f"Command {command_name} threw the following error: {arguments}"
    if command_name == "human_feedback":
        return f"Human feedback: {user_input}"

    # Give every capable plugin a chance to rewrite the command first.
    for plugin in cfg.plugins:
        if plugin.can_handle_pre_command():
            command_name, arguments = plugin.pre_command(command_name, arguments)

    command_result = execute_command(
        command_name,
        arguments,
        prompt,
    )
    return f"{command_result}"
|
||||
|
||||
|
||||
def execute_command(
    command_name: str,
    arguments,
    prompt: PluginPromptGenerator,
):
    """Execute the command and return the result

    Args:
        command_name (str): The name of the command to execute
        arguments (dict): The arguments for the command

    Returns:
        str: The result of the command

    Raises:
        NotCommands: When *command_name* matches nothing in the registry
            or in the plugin-declared command list.
    """
    # First try the decorated-command registry.
    cmd = prompt.command_registry.commands.get(command_name)

    # If the command is found, call it with the provided arguments
    if cmd:
        try:
            return cmd(**arguments)
        except Exception as e:
            return f"Error: {str(e)}"
    # TODO: Change these to take in a file rather than pasted code, if
    # non-file is given, return instructions "Input should be a python
    # filepath, write your code to file and try again
    else:
        # Fall back to plugin-declared commands; match on label or name.
        # NOTE(review): command_name is compared against lowercased
        # label/name but is not lowercased itself — confirm callers
        # always pass lowercase.
        for command in prompt.commands:
            if (
                command_name == command["label"].lower()
                or command_name == command["name"].lower()
            ):
                try:
                    # Drop arguments the command does not declare
                    # (original comment: 删除非定义参数).
                    diff_ags = list(
                        set(arguments.keys()).difference(set(command["args"].keys()))
                    )
                    for arg_name in diff_ags:
                        del arguments[arg_name]
                    print(str(arguments))
                    return command["function"](**arguments)
                except Exception as e:
                    return f"Error: {str(e)}"
        # Nothing matched anywhere — surface as NotCommands.
        raise NotCommands("非可用命令" + command_name)
|
||||
|
||||
|
||||
def get_command(response_json: Dict):
    """Parse the response and return the command name and arguments

    Args:
        response_json (json): The response from the AI

    Returns:
        tuple: The command name and arguments

    Raises:
        json.decoder.JSONDecodeError: If the response is not valid JSON

        Exception: If any other error occurs
    """
    try:
        # BUG FIX: validate the container type BEFORE the membership test.
        # The original checked `"command" not in response_json` first, which
        # raises TypeError for non-iterable input (e.g. int) and made the
        # "not dictionary" message effectively unreachable.
        if not isinstance(response_json, dict):
            return "Error:", f"'response_json' object is not dictionary {response_json}"

        if "command" not in response_json:
            return "Error:", "Missing 'command' object in JSON"

        command = response_json["command"]
        if not isinstance(command, dict):
            return "Error:", "'command' object is not a dictionary"

        if "name" not in command:
            return "Error:", "Missing 'name' field in 'command' object"

        command_name = command["name"]

        # Use an empty dictionary if 'args' field is not present in 'command' object
        arguments = command.get("args", {})

        return command_name, arguments
    except json.decoder.JSONDecodeError:
        return "Error:", "Invalid JSON"
    # All other errors, return "Error: + error message"
    except Exception as e:
        return "Error:", str(e)
|
156
pilot/base_modules/agent/commands/command_mange.py
Normal file
156
pilot/base_modules/agent/commands/command_mange.py
Normal file
@ -0,0 +1,156 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
# Unique identifier for auto-gpt commands
|
||||
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
|
||||
|
||||
|
||||
class Command:
    """A single executable command.

    Attributes:
        name (str): The name of the command.
        description (str): A brief description of what the command does.
        signature (str): The signature of the function that the command executes. Defaults to None.
    """

    def __init__(
        self,
        name: str,
        description: str,
        method: Callable[..., Any],
        signature: str = "",
        enabled: bool = True,
        disabled_reason: Optional[str] = None,
    ):
        self.name = name
        self.description = description
        self.method = method
        # Derive the signature from the callable when none was supplied.
        self.signature = signature or str(inspect.signature(self.method))
        self.enabled = enabled
        self.disabled_reason = disabled_reason

    def __call__(self, *args, **kwargs) -> Any:
        if self.enabled:
            return self.method(*args, **kwargs)
        return f"Command '{self.name}' is disabled: {self.disabled_reason}"

    def __str__(self) -> str:
        return f"{self.name}: {self.description}, args: {self.signature}"
|
||||
|
||||
|
||||
class CommandRegistry:
|
||||
"""
|
||||
The CommandRegistry class is a manager for a collection of Command objects.
|
||||
It allows the registration, modification, and retrieval of Command objects,
|
||||
as well as the scanning and loading of command plugins from a specified
|
||||
directory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.commands = {}
|
||||
|
||||
def _import_module(self, module_name: str) -> Any:
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
def _reload_module(self, module: Any) -> Any:
|
||||
return importlib.reload(module)
|
||||
|
||||
def register(self, cmd: Command) -> None:
|
||||
self.commands[cmd.name] = cmd
|
||||
|
||||
def unregister(self, command_name: str):
|
||||
if command_name in self.commands:
|
||||
del self.commands[command_name]
|
||||
else:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
|
||||
def reload_commands(self) -> None:
|
||||
"""Reloads all loaded command plugins."""
|
||||
for cmd_name in self.commands:
|
||||
cmd = self.commands[cmd_name]
|
||||
module = self._import_module(cmd.__module__)
|
||||
reloaded_module = self._reload_module(module)
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
return self.commands[name]
|
||||
|
||||
def call(self, command_name: str, **kwargs) -> Any:
|
||||
if command_name not in self.commands:
|
||||
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||
command = self.commands[command_name]
|
||||
return command(**kwargs)
|
||||
|
||||
def command_prompt(self) -> str:
|
||||
"""
|
||||
Returns a string representation of all registered `Command` objects for use in a prompt
|
||||
"""
|
||||
commands_list = [
|
||||
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
|
||||
]
|
||||
return "\n".join(commands_list)
|
||||
|
||||
def import_commands(self, module_name: str) -> None:
|
||||
"""
|
||||
Imports the specified Python module containing command plugins.
|
||||
|
||||
This method imports the associated module and registers any functions or
|
||||
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
|
||||
as `Command` objects. The registered `Command` objects are then added to the
|
||||
`commands` dictionary of the `CommandRegistry` object.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import for command plugins.
|
||||
"""
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
for attr_name in dir(module):
|
||||
attr = getattr(module, attr_name)
|
||||
# Register decorated functions
|
||||
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
|
||||
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||
):
|
||||
self.register(attr.command)
|
||||
# Register command classes
|
||||
elif (
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
):
|
||||
cmd_instance = attr()
|
||||
self.register(cmd_instance)
|
||||
|
||||
|
||||
def command(
    name: str,
    description: str,
    signature: str = "",
    enabled: bool = True,
    disabled_reason: Optional[str] = None,
) -> Callable[..., Any]:
    """The command decorator is used to create Command objects from ordinary functions."""

    def decorator(func: Callable[..., Any]) -> Command:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            return func(*args, **kwargs)

        # Attach the Command metadata and the marker flag that
        # CommandRegistry.import_commands() looks for.
        wrapper.command = Command(
            name=name,
            description=description,
            method=func,
            signature=signature,
            enabled=enabled,
            disabled_reason=disabled_reason,
        )
        setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)

        return wrapper

    return decorator
|
296
pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
Normal file
296
pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
Normal file
@ -0,0 +1,296 @@
|
||||
from pandas import DataFrame
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
import pandas as pd
|
||||
import uuid
|
||||
import os
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mtick
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||
static_message_img_path = os.path.join(os.getcwd(), "message/img")
|
||||
|
||||
|
||||
def data_pre_classification(df: DataFrame):
    """Split *df*'s columns into chart axes.

    Columns are bucketed into numeric and non-numeric groups, each ordered by
    ascending count of distinct values.  The x axis is the non-numeric column
    with the most distinct values (or the first numeric column when no
    non-numeric columns exist); the y axis is the numeric column with the
    fewest distinct values among the remainder.

    Returns:
        tuple: (x_column, y_column, remaining_non_numeric, remaining_numeric)

    Raises:
        ValueError: If no numeric column is left to serve as the y axis.
    """
    number_columns = []
    # Distinct-value counts per column, split by dtype.
    numeric_counts = {}
    non_numeric_counts = {}

    for col in df.columns.tolist():
        distinct = len(df[col].unique())
        if pd.api.types.is_numeric_dtype(df[col].dtypes):
            number_columns.append(col)
            numeric_counts[col] = distinct
        else:
            non_numeric_counts[col] = distinct

    numeric_sorted = [
        c for c, _ in sorted(numeric_counts.items(), key=lambda kv: kv[1])
    ]
    non_numeric_sorted = [
        c for c, _ in sorted(non_numeric_counts.items(), key=lambda kv: kv[1])
    ]

    # x axis: prefer the most-distinct non-numeric column.
    if non_numeric_sorted:
        x_column = non_numeric_sorted.pop()
    else:
        x_column = number_columns[0]
        numeric_sorted.remove(x_column)

    # y axis: the least-distinct remaining numeric column.
    if not numeric_sorted:
        raise ValueError("Not enough numeric columns for chart!")
    y_column = numeric_sorted.pop(0)

    return x_column, y_column, non_numeric_sorted, numeric_sorted
|
||||
|
||||
|
||||
def zh_font_set():
    """Point matplotlib's sans-serif list at installed CJK-capable fonts."""
    preferred = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    installed = {f.name for f in FontManager().ttflist}
    # Preserve preference order while keeping only installed fonts.
    usable = [name for name in preferred if name in installed]
    if usable:
        plt.rcParams["font.sans-serif"] = usable
|
||||
|
||||
|
||||
@command(
    "response_line_chart",
    "Line chart display, used to display comparative trend analysis data",
    '"speak": "<speak>", "df":"<data frame>"',
)
def response_line_chart(speak: str, df: DataFrame) -> str:
    """Render *df* as a line-chart PNG and return an HTML <img> snippet.

    Args:
        speak: Natural-language summary shown above the chart.
        df: Data to plot; must be non-empty.

    Returns:
        str: HTML fragment embedding the saved chart image.

    Raises:
        ValueError: If *df* is empty or lacks a numeric y column.
    """
    logger.info(f"response_line_chart:{speak},")
    if df.size <= 0:
        raise ValueError("No Data!")

    # set font
    # zh_font_set()
    # Collect installed CJK-capable fonts so axis labels render correctly.
    font_names = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    fm = FontManager()
    mat_fonts = set(f.name for f in fm.ttflist)
    can_use_fonts = []
    for font_name in font_names:
        if font_name in mat_fonts:
            can_use_fonts.append(font_name)
    if len(can_use_fonts) > 0:
        plt.rcParams["font.sans-serif"] = can_use_fonts

    rc = {"font.sans-serif": can_use_fonts}
    plt.rcParams["axes.unicode_minus"] = False  # so the minus sign renders correctly
    # NOTE(review): can_use_fonts[0] raises IndexError when none of the
    # preferred fonts is installed — confirm target environments.
    sns.set(font=can_use_fonts[0], font_scale=0.8)  # fix Chinese text in seaborn
    sns.set_palette("Set3")  # color theme
    sns.set_style("dark")
    sns.color_palette("hls", 10)
    sns.hls_palette(8, l=0.5, s=0.7)
    sns.set(context="notebook", style="ticks", rc=rc)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    x, y, non_num_columns, num_colmns = data_pre_classification(df)
    # ## multi-series line chart (original comment: 复杂折线图实现)
    if len(num_colmns) > 0:
        # Melt the extra numeric columns so each becomes its own line.
        num_colmns.append(y)
        df_melted = pd.melt(
            df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
        )
        sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
    else:
        sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")

    # Format both axes' ticks as plain comma-grouped integers.
    ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
    ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))

    # Save under a unique name in the shared message-image directory.
    chart_name = "line_" + str(uuid.uuid1()) + ".png"
    chart_path = static_message_img_path + "/" + chart_name
    plt.savefig(chart_path, bbox_inches="tight", dpi=100)

    html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img
|
||||
|
||||
|
||||
@command(
    "response_bar_chart",
    "Histogram, suitable for comparative analysis of multiple target values",
    '"speak": "<speak>", "df":"<data frame>"',
)
def response_bar_chart(speak: str, df: DataFrame) -> str:
    """Render *df* as a bar-chart PNG and return an HTML <img> snippet.

    Args:
        speak: Natural-language summary shown above the chart.
        df: Data to plot; must be non-empty.

    Returns:
        str: HTML fragment embedding the saved chart image.

    Raises:
        ValueError: If *df* is empty or lacks a numeric y column.
    """
    logger.info(f"response_bar_chart:{speak},")
    if df.size <= 0:
        raise ValueError("No Data!")

    # set font
    # zh_font_set()
    # Collect installed CJK-capable fonts so axis labels render correctly.
    font_names = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    fm = FontManager()
    mat_fonts = set(f.name for f in fm.ttflist)
    can_use_fonts = []
    for font_name in font_names:
        if font_name in mat_fonts:
            can_use_fonts.append(font_name)
    if len(can_use_fonts) > 0:
        plt.rcParams["font.sans-serif"] = can_use_fonts

    rc = {"font.sans-serif": can_use_fonts}
    plt.rcParams["axes.unicode_minus"] = False  # so the minus sign renders correctly
    # NOTE(review): can_use_fonts[0] raises IndexError when none of the
    # preferred fonts is installed — confirm target environments.
    sns.set(font=can_use_fonts[0], font_scale=0.8)  # fix Chinese text in seaborn
    sns.set_palette("Set3")  # color theme
    sns.set_style("dark")
    sns.color_palette("hls", 10)
    sns.hls_palette(8, l=0.5, s=0.7)
    sns.set(context="notebook", style="ticks", rc=rc)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)

    hue = None
    x, y, non_num_columns, num_colmns = data_pre_classification(df)
    # Use the first leftover categorical column to split bars, if any.
    if len(non_num_columns) >= 1:
        hue = non_num_columns[0]

    if len(num_colmns) >= 1:
        if hue:
            # With a hue, cap the extra numeric series at two and overlay them.
            if len(num_colmns) >= 2:
                can_use_columns = num_colmns[:2]
            else:
                can_use_columns = num_colmns
            sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
            for sub_y_column in can_use_columns:
                sns.barplot(
                    data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
                )
        else:
            # No hue: melt up to five extra numeric columns into one series axis.
            if len(num_colmns) > 5:
                can_use_columns = num_colmns[:5]
            else:
                can_use_columns = num_colmns
            can_use_columns.append(y)

            df_melted = pd.melt(
                df,
                id_vars=x,
                value_vars=can_use_columns,
                var_name="line",
                value_name="Value",
            )
            sns.barplot(
                data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
            )
    else:
        sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)

    # Format both axes' ticks as plain comma-grouped integers
    # (original comment: 设置 y 轴刻度格式为普通数字格式).
    ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
    ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))

    # Save under a unique name in the shared message-image directory.
    chart_name = "bar_" + str(uuid.uuid1()) + ".png"
    chart_path = static_message_img_path + "/" + chart_name
    plt.savefig(chart_path, bbox_inches="tight", dpi=100)
    html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img
|
||||
|
||||
|
||||
@command(
    "response_pie_chart",
    "Pie chart, suitable for scenarios such as proportion and distribution statistics",
    '"speak": "<speak>", "df":"<data frame>"',
)
def response_pie_chart(speak: str, df: DataFrame) -> str:
    """Render *df* as a pie chart PNG and return an HTML snippet embedding it.

    Args:
        speak: The model's natural-language remark, shown above the chart.
        df: Data to plot; column 0 supplies the slice labels and column 1 the
            values (assumption from the indexing below — TODO confirm with callers).

    Returns:
        str: HTML containing the remark and an ``<img>`` pointing at the chart
        saved under ``static_message_img_path``.

    Raises:
        ValueError: If *df* contains no data.
    """
    logger.info(f"response_pie_chart:{speak},")
    columns = df.columns.tolist()
    if df.size <= 0:
        raise ValueError("No Data!")
    # set font
    # zh_font_set()
    # Candidate CJK-capable fonts; only those actually installed are used.
    font_names = [
        "Heiti TC",
        "Songti SC",
        "STHeiti Light",
        "Microsoft YaHei",
        "SimSun",
        "SimHei",
        "KaiTi",
    ]
    fm = FontManager()
    mat_fonts = set(f.name for f in fm.ttflist)
    can_use_fonts = []
    for font_name in font_names:
        if font_name in mat_fonts:
            can_use_fonts.append(font_name)
    if len(can_use_fonts) > 0:
        plt.rcParams["font.sans-serif"] = can_use_fonts
    plt.rcParams["axes.unicode_minus"] = False  # render minus signs correctly

    sns.set_palette("Set3")  # set the color theme

    # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    ax = df.plot(
        kind="pie",
        y=columns[1],
        ax=ax,
        labels=df[columns[0]].values,
        startangle=90,
        autopct="%1.1f%%",
    )

    plt.axis("equal")  # keep the pie circular
    # plt.title(columns[0])

    # uuid1 keeps concurrent requests from clobbering each other's files.
    chart_name = "pie_" + str(uuid.uuid1()) + ".png"
    chart_path = static_message_img_path + "/" + chart_name
    plt.savefig(chart_path, bbox_inches="tight", dpi=100)

    html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

    return html_img
|
@ -0,0 +1,24 @@
|
||||
from pandas import DataFrame
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
|
||||
|
||||
@command(
    "response_table",
    "Table display, suitable for display with many display columns or non-numeric columns",
    '"speak": "<speak>", "df":"<data frame>"',
)
def response_table(speak: str, df: DataFrame) -> str:
    """Render *df* as a compact HTML table prefixed with the model's remark."""
    logger.info(f"response_table:{speak}")
    raw_html = df.to_html(index=False, escape=False, sparsify=False)
    # Strip every whitespace run so the table travels as one line of markup.
    compact = "".join(raw_html.split())
    wrapper = f"""<div class="w-full overflow-auto">{compact}</div>"""
    return f"##### {str(speak)}" + "\n" + wrapper.replace("\n", " ")
|
@ -0,0 +1,39 @@
|
||||
from pandas import DataFrame
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
|
||||
|
||||
@command(
    "response_data_text",
    "Text display, the default display method, suitable for single-line or simple content display",
    '"speak": "<speak>", "df":"<data frame>"',
)
def response_data_text(speak: str, df: DataFrame) -> str:
    """Render *df* as markdown-flavoured text chosen by row count.

    Multiple rows become an HTML table, a single row becomes an inline list of
    bolded values, and an empty frame yields a "no data" notice.
    """
    logger.info(f"response_data_text:{speak}")
    data = df.values

    row_size = data.shape[0]
    if row_size > 1:
        # Several rows: fall back to a compact single-line HTML table.
        html_table = df.to_html(index=False, escape=False, sparsify=False)
        table_str = "".join(html_table.split())
        html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
        return f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
    if row_size == 1:
        # One row: list each cell bolded, comma-separated.
        value_str = ",".join(f" ** {value} **" for value in data[0])
        return f"{speak}: {value_str}"
    return f"##### {speak}: _没有找到可用的数据_"
|
@ -0,0 +1,4 @@
|
||||
class NotCommands(Exception):
    """Raised when input does not correspond to any known command."""

    def __init__(self, message):
        # Keep the message both as the standard Exception payload and as a
        # directly readable attribute.
        self.message = message
        super().__init__(message)
|
158
pilot/base_modules/agent/commands/generator.py
Normal file
158
pilot/base_modules/agent/commands/generator.py
Normal file
@ -0,0 +1,158 @@
|
||||
""" A module for generating custom prompt strings."""
|
||||
import json
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
class PluginPromptGenerator:
    """
    A class for generating custom prompt strings based on constraints, commands,
    resources, and performance evaluations.
    """

    def __init__(self) -> None:
        """
        Initialize the PromptGenerator object with empty lists of constraints,
        commands, resources, and performance evaluations.
        """
        self.constraints = []
        self.commands = []
        self.resources = []
        self.performance_evaluation = []
        self.goals = []
        # Optional registry of externally defined commands; when set, its
        # enabled commands are merged into the numbered command list.
        self.command_registry = None
        self.name = "Bob"
        self.role = "AI"
        # Reply template the model is asked to follow (serialized as JSON in
        # generate_prompt_string()).
        self.response_format = {
            "thoughts": {
                "text": "thought",
                "reasoning": "reasoning",
                "plan": "- short bulleted\n- list that conveys\n- long-term plan",
                "criticism": "constructive self-criticism",
                "speak": "thoughts summary to say to user",
            },
            "command": {"name": "command name", "args": {"arg name": "value"}},
        }

    def add_constraint(self, constraint: str) -> None:
        """
        Add a constraint to the constraints list.

        Args:
            constraint (str): The constraint to be added.
        """
        self.constraints.append(constraint)

    def add_command(
        self,
        command_label: str,
        command_name: str,
        args=None,
        function: Optional[Callable] = None,
    ) -> None:
        """
        Add a command to the commands list with a label, name, and optional arguments.

        Args:
            command_label (str): The label of the command.
            command_name (str): The name of the command.
            args (dict, optional): A dictionary containing argument names and their
                values. Defaults to None.
            function (callable, optional): A callable function to be called when
                the command is executed. Defaults to None.
        """
        if args is None:
            args = {}

        # Copy so later mutation of the caller's dict cannot change the command.
        command_args = dict(args)

        self.commands.append(
            {
                "label": command_label,
                "name": command_name,
                "args": command_args,
                "function": function,
            }
        )

    def _generate_command_string(self, command: Dict[str, Any]) -> str:
        """
        Generate a formatted string representation of a command.

        Args:
            command (dict): A dictionary containing command information.

        Returns:
            str: The formatted command string.
        """
        args_string = ", ".join(
            f'"{key}": "{value}"' for key, value in command["args"].items()
        )
        return f'{command["label"]}: "{command["name"]}", args: {args_string}'

    def add_resource(self, resource: str) -> None:
        """
        Add a resource to the resources list.

        Args:
            resource (str): The resource to be added.
        """
        self.resources.append(resource)

    def add_performance_evaluation(self, evaluation: str) -> None:
        """
        Add a performance evaluation item to the performance_evaluation list.

        Args:
            evaluation (str): The evaluation item to be added.
        """
        self.performance_evaluation.append(evaluation)

    def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
        """
        Generate a numbered list from given items based on the item_type.

        Args:
            items (list): A list of items to be numbered.
            item_type (str, optional): The type of items in the list.
                Defaults to 'list'.

        Returns:
            str: The formatted numbered list.
        """
        if item_type == "command":
            command_strings = []
            if self.command_registry:
                command_strings += [
                    str(item)
                    for item in self.command_registry.commands.values()
                    if item.enabled
                ]
            # terminate command is added manually
            command_strings += [self._generate_command_string(item) for item in items]
            items = command_strings
        return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))

    def generate_commands_string(self) -> str:
        """Return only the numbered command list (registry + added commands)."""
        return self._generate_numbered_list(self.commands, item_type="command")

    def generate_prompt_string(self) -> str:
        """
        Generate a prompt string based on the constraints, commands, resources,
        and performance evaluations.

        Returns:
            str: The generated prompt string.
        """
        formatted_response_format = json.dumps(self.response_format, indent=4)
        return (
            f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
            "Commands:\n"
            f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
            f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
            "Performance Evaluation:\n"
            f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
            # BUG FIX: the adjacent literals previously concatenated to
            # "ensure theresponse" — a space was missing after "the".
            "You should only respond in JSON format as described below and ensure the "
            "response can be parsed by Python json.loads \nResponse"
            f" Format: \n{formatted_response_format}"
        )
|
10
pilot/base_modules/agent/commands/times.py
Normal file
10
pilot/base_modules/agent/commands/times.py
Normal file
@ -0,0 +1,10 @@
|
||||
from datetime import datetime


def get_datetime() -> str:
    """Return the current date and time

    Returns:
        str: The current date and time
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return f"Current date and time: {stamp}"
|
26
pilot/base_modules/agent/controller.py
Normal file
26
pilot/base_modules/agent/controller.py
Normal file
@ -0,0 +1,26 @@
|
||||
import json
|
||||
import time
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
)
|
||||
|
||||
from typing import List
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
)
|
||||
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
|
||||
|
||||
|
||||
@router.get("/v1/mange/agent/list", response_model=Result[str])
async def get_agent_list():
    """List available agents.

    NOTE(review): currently a stub — it logs the call and always returns an
    empty success ``Result``; no agent data is queried yet.
    """
    logger.info(f"get_agent_list!")

    return Result.succ(None)
|
136
pilot/base_modules/agent/db/my_plugin_db.py
Normal file
136
pilot/base_modules/agent/db/my_plugin_db.py
Normal file
@ -0,0 +1,136 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from pilot.base_modules.meta_data.meta_data import Base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
|
||||
|
||||
class MyPluginEntity(Base):
    """ORM row for a plugin a user has installed from the hub.

    NOTE(review): ``String`` columns carry no length; fine for SQLite but some
    backends (e.g. MySQL) require one for DDL — confirm target database.
    """

    __tablename__ = 'my_plugin'

    id = Column(Integer, primary_key=True, comment="autoincrement id")
    tenant = Column(String, nullable=True, comment="user's tenant")
    user_code = Column(String, nullable=True, comment="user code")
    user_name = Column(String, nullable=True, comment="user name")
    name = Column(String, unique=True, nullable=False, comment="plugin name")
    type = Column(String, comment="plugin type")
    version = Column(String, comment="plugin version")
    use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
    succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
    created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
    # Named unique constraint on ``name`` (in addition to unique=True above)
    # so the generated DDL gets a stable index name.
    __table_args__ = (
        UniqueConstraint('name', name="uk_name"),
    )
|
||||
|
||||
|
||||
class MyPluginDao(BaseDao[MyPluginEntity]):
    """Data-access object for ``MyPluginEntity`` rows.

    Each method obtains a session via ``self.Session()`` and closes it before
    returning, so a single DAO instance can be shared.
    """

    def __init__(self):
        super().__init__(
            database="dbgpt", orm_base=Base, db_engine=engine, session=session
        )

    def add(self, engity: MyPluginEntity):
        """Insert a copy of *engity* and return the generated primary key."""
        session = self.Session()
        my_plugin = MyPluginEntity(
            tenant=engity.tenant,
            user_code=engity.user_code,
            user_name=engity.user_name,
            name=engity.name,
            type=engity.type,
            version=engity.version,
            use_count=engity.use_count or 0,
            succ_count=engity.succ_count or 0,
            created_at=datetime.now(),
        )
        session.add(my_plugin)
        session.commit()
        # Renamed from ``id`` to avoid shadowing the builtin.
        plugin_id = my_plugin.id
        session.close()
        return plugin_id

    def update(self, entity: MyPluginEntity):
        """Merge *entity* into the database and return its id."""
        session = self.Session()
        updated = session.merge(entity)
        session.commit()
        # BUG FIX: read the id before closing and close the session afterwards;
        # the original returned without closing, leaking the session (every
        # sibling method here closes its session).
        updated_id = updated.id
        session.close()
        return updated_id

    def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
        """Return one page of rows matching every non-None field of *query*.

        Results are ordered by id descending; *page* is 1-based.
        """
        session = self.Session()
        my_plugins = session.query(MyPluginEntity)
        if query.id is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
        if query.name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
        if query.tenant is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
        if query.type is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
        if query.user_code is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
        if query.user_name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)

        my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
        my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
        result = my_plugins.all()
        session.close()
        return result

    def count(self, query: MyPluginEntity):
        """Return the number of rows matching every non-None field of *query*."""
        session = self.Session()
        my_plugins = session.query(func.count(MyPluginEntity.id))
        if query.id is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
        if query.name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
        if query.type is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
        if query.tenant is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
        if query.user_code is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
        if query.user_name is not None:
            my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
        count = my_plugins.scalar()
        session.close()
        return count

    def delete(self, plugin_id: int):
        """Delete the row whose primary key is *plugin_id*.

        Raises:
            Exception: If *plugin_id* is None.
        """
        session = self.Session()
        if plugin_id is None:
            raise Exception("plugin_id is None")
        # Filter on the id directly instead of building a throwaway entity.
        my_plugins = session.query(MyPluginEntity).filter(
            MyPluginEntity.id == plugin_id
        )
        my_plugins.delete()
        session.commit()
        session.close()
136
pilot/base_modules/agent/db/plugin_hub_db.py
Normal file
136
pilot/base_modules/agent/db/plugin_hub_db.py
Normal file
@ -0,0 +1,136 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from pilot.base_modules.meta_data.meta_data import Base
|
||||
|
||||
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
|
||||
|
||||
class PluginHubEntity(Base):
    """ORM row for a plugin published in the plugin hub.

    NOTE(review): ``String`` columns carry no length; fine for SQLite but some
    backends (e.g. MySQL) require one for DDL — confirm target database.
    """

    __tablename__ = 'plugin_hub'

    id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
    name = Column(String, unique=True, nullable=False, comment="plugin name")
    author = Column(String, nullable=True, comment="plugin author")
    email = Column(String, nullable=True, comment="plugin author email")
    type = Column(String, comment="plugin type")
    version = Column(String, comment="plugin version")
    storage_channel = Column(String, comment="plugin storage channel")
    storage_url = Column(String, comment="plugin download url")
    created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
    installed = Column(Boolean, default=False, comment="plugin already installed")

    # Named unique constraint on ``name`` plus a query index on ``type``.
    __table_args__ = (
        UniqueConstraint('name', name="uk_name"),
        Index('idx_q_type', 'type'),
    )
|
||||
|
||||
|
||||
class PluginHubDao(BaseDao[PluginHubEntity]):
    """Data-access object for ``PluginHubEntity`` rows.

    Each method obtains a session via ``self.Session()`` and closes it before
    returning, so a single DAO instance can be shared.
    """

    def __init__(self):
        super().__init__(
            database="dbgpt", orm_base=Base, db_engine=engine, session=session
        )

    def add(self, engity: PluginHubEntity):
        """Insert a copy of *engity* and return the generated primary key."""
        session = self.Session()
        plugin_hub = PluginHubEntity(
            name=engity.name,
            author=engity.author,
            email=engity.email,
            type=engity.type,
            version=engity.version,
            storage_channel=engity.storage_channel,
            storage_url=engity.storage_url,
            created_at=datetime.now(),
        )
        session.add(plugin_hub)
        session.commit()
        plugin_id = plugin_hub.id
        session.close()
        return plugin_id

    def update(self, entity: PluginHubEntity):
        """Merge *entity* into the database and return its id."""
        session = self.Session()
        updated = session.merge(entity)
        session.commit()
        # BUG FIX: read the id before closing and close the session afterwards;
        # the original returned without closing, leaking the session.
        updated_id = updated.id
        session.close()
        return updated_id

    def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
        """Return one page of rows matching every non-None field of *query*.

        Results are ordered by id descending; *page* is 1-based.
        """
        session = self.Session()
        plugin_hubs = session.query(PluginHubEntity)
        if query.id is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
        if query.name is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
        if query.type is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
        if query.author is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
        if query.storage_channel is not None:
            plugin_hubs = plugin_hubs.filter(
                PluginHubEntity.storage_channel == query.storage_channel
            )

        plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
        plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
        result = plugin_hubs.all()
        session.close()
        return result

    def get_by_name(self, name: str) -> PluginHubEntity:
        """Return the hub plugin named *name*, or None if it does not exist."""
        session = self.Session()
        try:
            # BUG FIX: the original called ``Query.get(1)``, which performs a
            # primary-key lookup for id 1 (and raises on a filtered query in
            # modern SQLAlchemy) instead of returning the row matching *name*.
            return (
                session.query(PluginHubEntity)
                .filter(PluginHubEntity.name == name)
                .first()
            )
        finally:
            session.close()

    def count(self, query: PluginHubEntity):
        """Return the number of rows matching every non-None field of *query*."""
        session = self.Session()
        plugin_hubs = session.query(func.count(PluginHubEntity.id))
        if query.id is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
        if query.name is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
        if query.type is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
        if query.author is not None:
            plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
        if query.storage_channel is not None:
            plugin_hubs = plugin_hubs.filter(
                PluginHubEntity.storage_channel == query.storage_channel
            )
        count = plugin_hubs.scalar()
        session.close()
        return count

    def delete(self, plugin_id: int):
        """Delete the row whose primary key is *plugin_id*.

        Raises:
            Exception: If *plugin_id* is None.
        """
        session = self.Session()
        if plugin_id is None:
            raise Exception("plugin_id is None")
        # Filter on the id directly instead of building a throwaway entity.
        plugin_hubs = session.query(PluginHubEntity).filter(
            PluginHubEntity.id == plugin_id
        )
        plugin_hubs.delete()
        session.commit()
        session.close()
|
0
pilot/base_modules/agent/hub/__init__.py
Normal file
0
pilot/base_modules/agent/hub/__init__.py
Normal file
69
pilot/base_modules/agent/hub/agent_hub.py
Normal file
69
pilot/base_modules/agent/hub/agent_hub.py
Normal file
@ -0,0 +1,69 @@
|
||||
import logging
|
||||
|
||||
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||
from .schema import PluginStorageType
|
||||
|
||||
logger = logging.getLogger("agent_hub")
|
||||
|
||||
|
||||
class AgentHub:
    """Installs plugins from the plugin hub into the local "my plugin" list."""

    def __init__(self) -> None:
        self.hub_dao = PluginHubDao()
        # NOTE(review): attribute keeps the original (misspelled) name for
        # backward compatibility with existing callers.
        self.my_lugin_dao = MyPluginDao()

    def install_plugin(self, plugin_name: str, user_name: str = None):
        """Download, load and record the hub plugin *plugin_name*.

        Args:
            plugin_name: Name of the plugin as registered in the hub.
            user_name: Optional owner of the installation.

        Raises:
            ValueError: If the plugin is unknown, its storage channel is not
                supported, or any step of the installation fails.
        """
        logger.info(f"install_plugin {plugin_name}")

        plugin_entity = self.hub_dao.get_by_name(plugin_name)
        if not plugin_entity:
            raise ValueError(f"Can't Find Plugin {plugin_name}!")
        if plugin_entity.storage_channel != PluginStorageType.Git.value:
            raise ValueError(
                f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
            )
        try:
            self.__download_from_git(plugin_name, plugin_entity.storage_url)
            self.load_plugin(plugin_name)

            # add to my plugins and edit hub status
            plugin_entity.installed = True

            my_plugin_entity = self.__build_my_plugin(plugin_entity)
            if not user_name:
                # TODO use user
                my_plugin_entity.user_code = ""
            my_plugin_entity.user_name = user_name
            my_plugin_entity.tenant = ""

            with self.hub_dao.Session() as session:
                try:
                    session.add(my_plugin_entity)
                    session.merge(plugin_entity)
                    session.commit()
                except Exception:
                    # BUG FIX: the original bare ``except:`` silently swallowed
                    # commit failures; roll back and re-raise so the outer
                    # handler reports the installation as failed.
                    session.rollback()
                    raise
        except Exception as e:
            # BUG FIX: ``logger.error("msg", e)`` misuses the lazy %-arg API;
            # attach the exception via ``exc_info`` instead.
            logger.error("install pluguin exception!", exc_info=e)
            raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")

    def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
        """Copy the hub record's identity fields into a new MyPluginEntity."""
        my_plugin_entity = MyPluginEntity()
        my_plugin_entity.name = hub_plugin.name
        my_plugin_entity.type = hub_plugin.type
        my_plugin_entity.version = hub_plugin.version
        return my_plugin_entity

    def __download_from_git(self, plugin_name, url):
        # TODO: fetch the plugin archive from *url*; currently a stub.
        pass

    def load_plugin(self, plugin_name):
        # TODO: import the downloaded plugin package; currently a stub.
        pass

    def get_my_plugin(self, user: str):
        # TODO: list plugins installed by *user*; currently a stub.
        pass

    def uninstall_plugin(self):
        # TODO: remove an installed plugin; currently a stub.
        pass
|
6
pilot/base_modules/agent/hub/schema.py
Normal file
6
pilot/base_modules/agent/hub/schema.py
Normal file
@ -0,0 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PluginStorageType(Enum):
    """Channels a plugin archive can be stored in and fetched from."""

    # Plugin is hosted in a git repository and downloaded as an archive.
    Git = "git"
    # Plugin archive lives in object storage (OSS).
    Oss = "oss"
|
182
pilot/base_modules/agent/plugins.py
Normal file
182
pilot/base_modules/agent/plugins.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""加载组件"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import glob
|
||||
import zipfile
|
||||
import requests
|
||||
import threading
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import zipimporter
|
||||
|
||||
import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
    """
    Loader zip plugin file. Native support Auto_gpt_plugin

    Args:
        zip_path (str): Path to the zipfile.
        debug (bool, optional): Enable debug logging. Defaults to False.

    Returns:
        list[str]: The list of module names found or empty list if none were found.
    """
    found_modules = []
    with zipfile.ZipFile(zip_path, "r") as archive:
        for entry in archive.namelist():
            # Only package initializers count; macOS resource-fork duplicates
            # ("__MACOSX/...") are ignored.
            if entry.endswith("__init__.py") and not entry.startswith("__MACOSX"):
                logger.debug(f"Found module '{entry}' in the zipfile at: {entry}")
                found_modules.append(entry)
    if not found_modules:
        logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
    return found_modules
|
||||
|
||||
|
||||
def write_dict_to_json_file(data: dict, file_path: str) -> None:
    """
    Write a dictionary to a JSON file.
    Args:
        data (dict): Dictionary to write.
        file_path (str): Path to the file.
    """
    serialized = json.dumps(data, indent=4)
    with open(file_path, "w") as handle:
        handle.write(serialized)
|
||||
|
||||
|
||||
def create_directory_if_not_exists(directory_path: str) -> bool:
    """
    Create a directory if it does not exist.
    Args:
        directory_path (str): Path to the directory.
    Returns:
        bool: True if the directory exists afterwards (created or pre-existing),
        False if creation failed.
    """
    if os.path.exists(directory_path):
        logger.info(f"Directory {directory_path} already exists")
        return True
    try:
        os.makedirs(directory_path)
    except OSError as e:
        logger.warn(f"Error creating directory {directory_path}: {e}")
        return False
    logger.debug(f"Created directory: {directory_path}")
    return True
|
||||
|
||||
|
||||
def load_native_plugins(cfg: Config):
    """Asynchronously download the native plugin repo zip and load its plugins.

    Spawns a daemon-less background thread; returns immediately. A no-op when
    ``cfg.plugins_auto_load`` is false.
    """
    if not cfg.plugins_auto_load:
        print("not auto load_native_plugins")
        return

    def load_from_git(cfg: Config):
        # Fetch the configured branch of the plugin repo as a zip archive.
        print("async load_native_plugins")
        branch_name = cfg.plugins_git_branch
        native_plugin_repo = "DB-GPT-Plugins"
        url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
        try:
            session = requests.Session()
            response = session.get(
                url.format(repo=native_plugin_repo, branch=branch_name),
                # SECURITY(review): hard-coded GitHub personal access token
                # committed to source control — revoke it and load it from
                # configuration/environment instead of shipping it in code.
                headers={"Authorization": "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"},
            )

            if response.status_code == 200:
                plugins_path_path = Path(PLUGINS_DIR)
                # Remove stale copies of previously downloaded archives.
                files = glob.glob(
                    os.path.join(plugins_path_path, f"{native_plugin_repo}*")
                )
                for file in files:
                    os.remove(file)
                now = datetime.datetime.now()
                time_str = now.strftime("%Y%m%d%H%M%S")
                file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
                print(file_name)
                with open(file_name, "wb") as f:
                    f.write(response.content)
                print("save file")
                # Re-scan so the freshly downloaded plugins become active.
                cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
            else:
                print("get file faild,response code:", response.status_code)
        except Exception as e:
            print("load plugin from git exception!" + str(e))

    t = threading.Thread(target=load_from_git, args=(cfg,))
    t.start()
|
||||
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.

    Args:
        cfg (Config): Config instance including plugins config
        debug (bool, optional): Enable debug logging. Defaults to False.

    Returns:
        List[Tuple[str, Path]]: List of plugins.
    """
    loaded_plugins = []
    current_dir = os.getcwd()
    print(current_dir)
    # Generic plugins
    plugins_path_path = Path(PLUGINS_DIR)

    for plugin in plugins_path_path.glob("*.zip"):
        if moduleList := inspect_zip_for_modules(str(plugin), debug):
            for module in moduleList:
                plugin = Path(plugin)
                module = Path(module)
                logger.debug(f"Plugin: {plugin} Module: {module}")
                # Import the module directly out of the zip archive.
                zipped_package = zipimporter(str(plugin))
                zipped_module = zipped_package.load_module(str(module.parent))
                for key in dir(zipped_module):
                    if key.startswith("__"):
                        continue
                    a_module = getattr(zipped_module, key)
                    a_keys = dir(a_module)
                    # "_abc_impl" is used as a heuristic for ABC-derived plugin
                    # classes (assumption — presumably AutoGPTPluginTemplate
                    # subclasses; TODO confirm).
                    if (
                        "_abc_impl" in a_keys
                        and a_module.__name__ != "AutoGPTPluginTemplate"
                        # and denylist_allowlist_check(a_module.__name__, cfg)
                    ):
                        # Instantiate the plugin class and keep the instance.
                        loaded_plugins.append(a_module())

    if loaded_plugins:
        logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
        for plugin in loaded_plugins:
            logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
    return loaded_plugins
|
||||
|
||||
|
||||
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
    """Check if the plugin is in the allowlist or denylist.

    Args:
        plugin_name (str): Name of the plugin.
        cfg (Config): Config object.

    Returns:
        True or False
    """
    logger.debug(f"Checking if plugin {plugin_name} should be loaded")

    if plugin_name in cfg.plugins_denylist:
        logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
        return False

    if plugin_name in cfg.plugins_allowlist:
        logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
        return True

    # Unknown plugin: ask the operator interactively.
    prompt = (
        f"WARNING: Plugin {plugin_name} found. But not in the"
        f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
    )
    answer = input(prompt)
    return answer.lower() == cfg.authorise_key
|
2
pilot/base_modules/agent/requirement.txt
Normal file
2
pilot/base_modules/agent/requirement.txt
Normal file
@ -0,0 +1,2 @@
|
||||
flask_sqlalchemy==3.0.5
|
||||
flask==2.3.2
|
0
pilot/base_modules/base.py
Normal file
0
pilot/base_modules/base.py
Normal file
7
pilot/base_modules/mange_base_api.py
Normal file
7
pilot/base_modules/mange_base_api.py
Normal file
@ -0,0 +1,7 @@
|
||||
class ModuleMangeApi:
    """Interface a pluggable module-management API is expected to implement.

    Both methods are stubs; concrete modules should override them.
    """

    def module_name(self):
        # Return this module's unique name; stub in the base class.
        pass

    def register(self):
        # Register the module with the application; stub in the base class.
        pass
|
0
pilot/base_modules/meta_data/__init__.py
Normal file
0
pilot/base_modules/meta_data/__init__.py
Normal file
21
pilot/base_modules/meta_data/base_dao.py
Normal file
21
pilot/base_modules/meta_data/base_dao.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import TypeVar, Generic, List, Any
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
T = TypeVar('T')


class BaseDao(Generic[T]):
    """Generic base class for data-access objects over one entity type ``T``."""

    def __init__(
        self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None,
    ) -> None:
        """BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
        self._orm_base = orm_base
        self._database = database

        self._db_engine = db_engine
        self._session = session

    @property
    def Session(self):
        """Return the session (factory); lazily built from the engine if unset."""
        if not self._session:
            # BUG FIX: the original read ``self.db_engine`` — no such attribute
            # exists (only ``self._db_engine``), so the lazy path always raised
            # AttributeError.
            self._session = sessionmaker(bind=self._db_engine)
        return self._session
|
109
pilot/base_modules/meta_data/meta_data.py
Normal file
109
pilot/base_modules/meta_data/meta_data.py
Normal file
@ -0,0 +1,109 @@
|
||||
import uuid
import os
import duckdb
import sqlite3
from datetime import datetime
from typing import Optional, Type, TypeVar


import sqlalchemy as sa


from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate,upgrade
from flask.cli import with_appcontext
import subprocess


from sqlalchemy import create_engine,DateTime, String, func, MetaData
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base


from alembic import context, command
from alembic.config import Config


# Directory (relative to the current working directory) that holds the
# metadata database and the Alembic migration environment.
default_db_path = os.path.join(os.getcwd(), "meta_data")

os.makedirs(default_db_path, exist_ok=True)


# SQLite file that stores DB-GPT's own metadata.
db_path = default_db_path + "/dbgpt.db"
# NOTE(review): this raw sqlite3 connection is opened at import time and
# never used or closed here — confirm whether it is still needed.
connection = sqlite3.connect(db_path)
engine = create_engine(f'sqlite:///{db_path}')

# Module-level session factory and one shared session bound to the engine.
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()

Base = declarative_base(bind=engine)

# Base.metadata.create_all()


# Create the Alembic configuration object


alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = Config(alembic_ini_path)

alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))

os.makedirs(default_db_path + "/alembic", exist_ok=True)
alembic_cfg.set_main_option('script_location', default_db_path + "/alembic")

# Hand the model metadata and the shared session to the Alembic config so the
# migration env (and ddl_init_and_upgrade below) can autogenerate revisions.
alembic_cfg.attributes['target_metadata'] = Base.metadata
alembic_cfg.attributes['session'] = session


# # Create the tables
# Base.metadata.create_all(engine)
#
# # Drop the tables
# Base.metadata.drop_all(engine)


# app = Flask(__name__)
# default_db_path = os.path.join(os.getcwd(), "meta_data")
# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db")
# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}'
# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# db = SQLAlchemy(app)
# migrate = Migrate(app, db)
#
# # Set the FLASK_APP environment variable
# import os
# os.environ['FLASK_APP'] = 'server.dbgpt_server.py'
#
# @app.cli.command("db_init")
# @with_appcontext
# def db_init():
#     subprocess.run(["flask", "db", "init"])
#
# @app.cli.command("db_migrate")
# @with_appcontext
# def db_migrate():
#     subprocess.run(["flask", "db", "migrate"])
#
# @app.cli.command("db_upgrade")
# @with_appcontext
# def db_upgrade():
#     subprocess.run(["flask", "db", "upgrade"])
#
||||
def ddl_init_and_upgrade():
    """Initialize/refresh the metadata schema via Alembic.

    Autogenerates a migration revision from the current model metadata and
    then applies all pending revisions to the SQLite metadata database.

    NOTE(review): a new revision named "test" is autogenerated on every call,
    even when the schema has not changed — confirm this is intended.
    """
    # Base.metadata.create_all(bind=engine)
    # Generate and apply the migration scripts
    # command.upgrade(alembic_cfg, 'head')
    # subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
    with engine.connect() as connection:
        # Hand the live connection to Alembic so it reuses it instead of
        # opening a new one from the URL.
        alembic_cfg.attributes['connection'] = connection
        # Positional args: message="test", autogenerate=True.
        command.revision(alembic_cfg, "test", True)
        command.upgrade(alembic_cfg, "head")
    # alembic_cfg.attributes['connection'] = engine.connect()
    # command.upgrade(alembic_cfg, 'head')

    # with app.app_context():
    #     db_init()
    #     db_migrate()
    #     db_upgrade()
|
||||
|
1
pilot/base_modules/meta_data/requirement.txt
Normal file
1
pilot/base_modules/meta_data/requirement.txt
Normal file
@ -0,0 +1 @@
|
||||
alembic==1.12.0
|
0
pilot/base_modules/module_factory.py
Normal file
0
pilot/base_modules/module_factory.py
Normal file
Loading…
Reference in New Issue
Block a user