fix(ChatExcel): ChatExcel OutParse Bug Fix

1.ChatExcel OutParse Bug Fix
This commit is contained in:
yhjun1026 2023-09-14 20:47:10 +08:00
parent cad2785d94
commit af68e9c4c0
36 changed files with 1740 additions and 220 deletions

4
.gitignore vendored
View File

@ -149,4 +149,6 @@ pilot/mock_datas/db-gpt-test.db.wal
logswebserver.log.*
.history/*
.plugin_env
.plugin_env
/pilot/meta_data/alembic/versions/*
/pilot/meta_data/*.db

View File

@ -1,2 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

View File

@ -1,12 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
class Agent:
"""Agent class for interacting with DB-GPT
Attributes:
"""
def __init__(self) -> None:
pass

View File

@ -1,83 +0,0 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""Agent manager for managing GPT agents"""
from __future__ import annotations
from pilot.configs.config import Config
from pilot.model.base import Message
from pilot.singleton import Singleton
class AgentManager(metaclass=Singleton):
"""Agent manager for managing GPT agents"""
def __init__(self):
self.next_key = 0
self.agents = {} # key, (task, full_message_history, model)
self.cfg = Config()
"""Agent manager for managing DB-GPT agents
In order to compatible auto gpt plugins,
we use the same template with it.
Args: next_keys
agents
cfg
"""
def __init__(self) -> None:
self.next_key = 0
self.agents = {} # TODO need to define
self.cfg = Config()
# Create new GPT agent
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
def create_agent(self, task: str, prompt: str, model: str) -> tuple[int, str]:
"""Create a new agent and return its key
Args:
task: The task to perform
prompt: The prompt to use
model: The model to use
Returns:
The key of the new agent
"""
def message_agent(self, key: str | int, message: str) -> str:
"""Send a message to an agent and return its response
Args:
key: The key of the agent to message
message: The message to send to the agent
Returns:
The agent's response
"""
def list_agents(self) -> list[tuple[str | int, str]]:
"""Return a list of all agents
Returns:
A list of tuples of the form (key, task)
"""
# Return a list of agent keys and their tasks
return [(key, task) for key, (task, _, _) in self.agents.items()]
def delete_agent(self, key: str | int) -> bool:
"""Delete an agent from the agent manager
Args:
key: The key of the agent to delete
Returns:
True if successful, False otherwise
"""
try:
del self.agents[int(key)]
return True
except KeyError:
return False

View File

@ -1,122 +0,0 @@
import contextlib
import json
from typing import Any, Dict
from colorama import Fore
from regex import regex
from pilot.configs.config import Config
from pilot.json_utils.json_fix_general import (
add_quotes_to_property_names,
balance_braces,
fix_invalid_escape,
)
from pilot.logs import logger
CFG = Config()
def fix_and_parse_json(
json_to_load: str, try_to_fix_with_gpt: bool = True
) -> Dict[Any, Any]:
"""Fix and parse JSON string
Args:
json_to_load (str): The JSON string.
try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
Defaults to True.
Returns:
str or dict[Any, Any]: The parsed JSON.
"""
with contextlib.suppress(json.JSONDecodeError):
json_to_load = json_to_load.replace("\t", "")
return json.loads(json_to_load)
with contextlib.suppress(json.JSONDecodeError):
json_to_load = correct_json(json_to_load)
return json.loads(json_to_load)
# Let's do something manually:
# sometimes GPT responds with something BEFORE the braces:
# "I'm sorry, I don't understand. Please try again."
# {"text": "I'm sorry, I don't understand. Please try again.",
# "confidence": 0.0}
# So let's try to find the first brace and then parse the rest
# of the string
try:
brace_index = json_to_load.index("{")
maybe_fixed_json = json_to_load[brace_index:]
last_brace_index = maybe_fixed_json.rindex("}")
maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
return json.loads(maybe_fixed_json)
except (json.JSONDecodeError, ValueError) as e:
logger.error("参数解析错误", e)
def correct_json(json_to_load: str) -> str:
"""
Correct common JSON errors.
Args:
json_to_load (str): The JSON string.
"""
try:
logger.debug("json", json_to_load)
json.loads(json_to_load)
return json_to_load
except json.JSONDecodeError as e:
logger.debug("json loads error", e)
error_message = str(e)
if error_message.startswith("Invalid \\escape"):
json_to_load = fix_invalid_escape(json_to_load, error_message)
if error_message.startswith(
"Expecting property name enclosed in double quotes"
):
json_to_load = add_quotes_to_property_names(json_to_load)
try:
json.loads(json_to_load)
return json_to_load
except json.JSONDecodeError as e:
logger.debug("json loads error - add quotes", e)
error_message = str(e)
if balanced_str := balance_braces(json_to_load):
return balanced_str
return json_to_load
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
from pilot.speech.say import say_text
if CFG.speak_mode and CFG.debug_mode:
say_text(
"I have received an invalid JSON response from the OpenAI API. "
"Trying to fix it now."
)
logger.error("Attempting to fix JSON by finding outermost brackets\n")
try:
json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
json_match = json_pattern.search(json_string)
if json_match:
# Extract the valid JSON object from the string
json_string = json_match.group(0)
logger.typewriter_log(
title="Apparently json was fixed.", title_color=Fore.GREEN
)
if CFG.speak_mode and CFG.debug_mode:
say_text("Apparently json was fixed.")
else:
return {}
except (json.JSONDecodeError, ValueError):
if CFG.debug_mode:
logger.error(f"Error: Invalid JSON: {json_string}\n")
if CFG.speak_mode:
say_text("Didn't work. I will have to ignore this response then.")
logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
json_string = {}
return fix_and_parse_json(json_string)

View File

View File

@ -0,0 +1,6 @@
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
from .commands.command import execute_command, get_command
from .commands.generator import PluginPromptGenerator
from .commands.disply_type.show_chart_gen import static_message_img_path

View File

@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
class AgentFacade(ABC):
def __init__(self) -> None:
self.model = None

View File

@ -0,0 +1,61 @@
"""Commands for converting audio to text."""
import json
import requests
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
CFG = Config()
@command(
"read_audio_from_file",
"Convert Audio to text",
'"filename": "<filename>"',
CFG.huggingface_audio_to_text_model,
"Configure huggingface_audio_to_text_model.",
)
def read_audio_from_file(filename: str) -> str:
"""
Convert audio to text.
Args:
filename (str): The path to the audio file
Returns:
str: The text from the audio
"""
with open(filename, "rb") as audio_file:
audio = audio_file.read()
return read_audio(audio)
def read_audio(audio: bytes) -> str:
"""
Convert audio to text.
Args:
audio (bytes): The audio to convert
Returns:
str: The text from the audio
"""
model = CFG.huggingface_audio_to_text_model
api_url = f"https://api-inference.huggingface.co/models/{model}"
api_token = CFG.huggingface_api_token
headers = {"Authorization": f"Bearer {api_token}"}
if api_token is None:
raise ValueError(
"You need to set your Hugging Face API token in the config file."
)
response = requests.post(
api_url,
headers=headers,
data=audio,
)
text = json.loads(response.content.decode("utf-8"))["text"]
return f"The audio says: {text}"

View File

@ -0,0 +1,123 @@
""" Image Generation Module for AutoGPT."""
import io
import uuid
from base64 import b64decode
import requests
from PIL import Image
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
from pilot.logs import logger
CFG = Config()
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
def generate_image(prompt: str, size: int = 256) -> str:
"""Generate an image from a prompt.
Args:
prompt (str): The prompt to use
size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
Returns:
str: The filename of the image
"""
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
# HuggingFace
if CFG.image_provider == "huggingface":
return generate_image_with_hf(prompt, filename)
# SD WebUI
elif CFG.image_provider == "sdwebui":
return generate_image_with_sd_webui(prompt, filename, size)
return "No Image Provider Set"
def generate_image_with_hf(prompt: str, filename: str) -> str:
"""Generate an image with HuggingFace's API.
Args:
prompt (str): The prompt to use
filename (str): The filename to save the image to
Returns:
str: The filename of the image
"""
API_URL = (
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
)
if CFG.huggingface_api_token is None:
raise ValueError(
"You need to set your Hugging Face API token in the config file."
)
headers = {
"Authorization": f"Bearer {CFG.huggingface_api_token}",
"X-Use-Cache": "false",
}
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
image.save(filename)
return f"Saved to disk:{filename}"
def generate_image_with_sd_webui(
prompt: str,
filename: str,
size: int = 512,
negative_prompt: str = "",
extra: dict = {},
) -> str:
"""Generate an image with Stable Diffusion webui.
Args:
prompt (str): The prompt to use
filename (str): The filename to save the image to
size (int, optional): The size of the image. Defaults to 256.
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
Returns:
str: The filename of the image
"""
# Create a session and set the basic auth if needed
s = requests.Session()
if CFG.sd_webui_auth:
username, password = CFG.sd_webui_auth.split(":")
s.auth = (username, password or "")
# Generate the images
response = requests.post(
f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
json={
"prompt": prompt,
"negative_prompt": negative_prompt,
"sampler_index": "DDIM",
"steps": 20,
"cfg_scale": 7.0,
"width": size,
"height": size,
"n_iter": 1,
**extra,
},
)
logger.info(f"Image Generated for prompt:{prompt}")
# Save the image to disk
response = response.json()
b64 = b64decode(response["images"][0].split(",", 1)[0])
image = Image.open(io.BytesIO(b64))
image.save(filename)
return f"Saved to disk:{filename}"

View File

@ -0,0 +1,152 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
from typing import Dict
from .exception_not_commands import NotCommands
from .generator import PluginPromptGenerator
from pilot.configs.config import Config
def _resolve_pathlike_command_args(command_args):
if "directory" in command_args and command_args["directory"] in {"", "/"}:
# todo
command_args["directory"] = ""
else:
for pathlike in ["filename", "directory", "clone_path"]:
if pathlike in command_args:
# todo
command_args[pathlike] = ""
return command_args
def execute_ai_response_json(
prompt: PluginPromptGenerator,
ai_response,
user_input: str = None,
) -> str:
"""
Args:
command_registry:
ai_response:
prompt:
Returns:
"""
from pilot.speech.say import say_text
cfg = Config()
command_name, arguments = get_command(ai_response)
if cfg.speak_mode:
say_text(f"I want to execute {command_name}")
arguments = _resolve_pathlike_command_args(arguments)
# Execute command
if command_name is not None and command_name.lower().startswith("error"):
result = f"Command {command_name} threw the following error: {arguments}"
elif command_name == "human_feedback":
result = f"Human feedback: {user_input}"
else:
for plugin in cfg.plugins:
if not plugin.can_handle_pre_command():
continue
command_name, arguments = plugin.pre_command(command_name, arguments)
command_result = execute_command(
command_name,
arguments,
prompt,
)
result = f"{command_result}"
return result
def execute_command(
command_name: str,
arguments,
prompt: PluginPromptGenerator,
):
"""Execute the command and return the result
Args:
command_name (str): The name of the command to execute
arguments (dict): The arguments for the command
Returns:
str: The result of the command
"""
cmd = prompt.command_registry.commands.get(command_name)
# If the command is found, call it with the provided arguments
if cmd:
try:
return cmd(**arguments)
except Exception as e:
return f"Error: {str(e)}"
# TODO: Change these to take in a file rather than pasted code, if
# non-file is given, return instructions "Input should be a python
# filepath, write your code to file and try again
else:
for command in prompt.commands:
if (
command_name == command["label"].lower()
or command_name == command["name"].lower()
):
try:
# 删除非定义参数
diff_ags = list(
set(arguments.keys()).difference(set(command["args"].keys()))
)
for arg_name in diff_ags:
del arguments[arg_name]
print(str(arguments))
return command["function"](**arguments)
except Exception as e:
return f"Error: {str(e)}"
raise NotCommands("非可用命令" + command_name)
def get_command(response_json: Dict):
"""Parse the response and return the command name and arguments
Args:
response_json (json): The response from the AI
Returns:
tuple: The command name and arguments
Raises:
json.decoder.JSONDecodeError: If the response is not valid JSON
Exception: If any other error occurs
"""
try:
if "command" not in response_json:
return "Error:", "Missing 'command' object in JSON"
if not isinstance(response_json, dict):
return "Error:", f"'response_json' object is not dictionary {response_json}"
command = response_json["command"]
if not isinstance(command, dict):
return "Error:", "'command' object is not a dictionary"
if "name" not in command:
return "Error:", "Missing 'name' field in 'command' object"
command_name = command["name"]
# Use an empty dictionary if 'args' field is not present in 'command' object
arguments = command.get("args", {})
return command_name, arguments
except json.decoder.JSONDecodeError:
return "Error:", "Invalid JSON"
# All other errors, return "Error: + error message"
except Exception as e:
return "Error:", str(e)

View File

@ -0,0 +1,156 @@
import functools
import importlib
import inspect
from typing import Any, Callable, Optional
# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
class Command:
"""A class representing a command.
Attributes:
name (str): The name of the command.
description (str): A brief description of what the command does.
signature (str): The signature of the function that the command executes. Defaults to None.
"""
def __init__(
self,
name: str,
description: str,
method: Callable[..., Any],
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
):
self.name = name
self.description = description
self.method = method
self.signature = signature if signature else str(inspect.signature(self.method))
self.enabled = enabled
self.disabled_reason = disabled_reason
def __call__(self, *args, **kwargs) -> Any:
if not self.enabled:
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
return self.method(*args, **kwargs)
def __str__(self) -> str:
return f"{self.name}: {self.description}, args: {self.signature}"
class CommandRegistry:
"""
The CommandRegistry class is a manager for a collection of Command objects.
It allows the registration, modification, and retrieval of Command objects,
as well as the scanning and loading of command plugins from a specified
directory.
"""
def __init__(self):
self.commands = {}
def _import_module(self, module_name: str) -> Any:
return importlib.import_module(module_name)
def _reload_module(self, module: Any) -> Any:
return importlib.reload(module)
def register(self, cmd: Command) -> None:
self.commands[cmd.name] = cmd
def unregister(self, command_name: str):
if command_name in self.commands:
del self.commands[command_name]
else:
raise KeyError(f"Command '{command_name}' not found in registry.")
def reload_commands(self) -> None:
"""Reloads all loaded command plugins."""
for cmd_name in self.commands:
cmd = self.commands[cmd_name]
module = self._import_module(cmd.__module__)
reloaded_module = self._reload_module(module)
if hasattr(reloaded_module, "register"):
reloaded_module.register(self)
def get_command(self, name: str) -> Callable[..., Any]:
return self.commands[name]
def call(self, command_name: str, **kwargs) -> Any:
if command_name not in self.commands:
raise KeyError(f"Command '{command_name}' not found in registry.")
command = self.commands[command_name]
return command(**kwargs)
def command_prompt(self) -> str:
"""
Returns a string representation of all registered `Command` objects for use in a prompt
"""
commands_list = [
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
]
return "\n".join(commands_list)
def import_commands(self, module_name: str) -> None:
"""
Imports the specified Python module containing command plugins.
This method imports the associated module and registers any functions or
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
as `Command` objects. The registered `Command` objects are then added to the
`commands` dictionary of the `CommandRegistry` object.
Args:
module_name (str): The name of the module to import for command plugins.
"""
module = importlib.import_module(module_name)
for attr_name in dir(module):
attr = getattr(module, attr_name)
# Register decorated functions
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
attr, AUTO_GPT_COMMAND_IDENTIFIER
):
self.register(attr.command)
# Register command classes
elif (
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
):
cmd_instance = attr()
self.register(cmd_instance)
def command(
name: str,
description: str,
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
) -> Callable[..., Any]:
"""The command decorator is used to create Command objects from ordinary functions."""
def decorator(func: Callable[..., Any]) -> Command:
cmd = Command(
name=name,
description=description,
method=func,
signature=signature,
enabled=enabled,
disabled_reason=disabled_reason,
)
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
return func(*args, **kwargs)
wrapper.command = cmd
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
return wrapper
return decorator

View File

@ -0,0 +1,296 @@
from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
import uuid
import os
import matplotlib
import seaborn as sns
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
static_message_img_path = os.path.join(os.getcwd(), "message/img")
def data_pre_classification(df: DataFrame):
## Data pre-classification
columns = df.columns.tolist()
number_columns = []
non_numeric_colums = []
# 收集数据分类小于10个的列
non_numeric_colums_value_map = {}
numeric_colums_value_map = {}
for column_name in columns:
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
number_columns.append(column_name)
unique_values = df[column_name].unique()
numeric_colums_value_map.update({column_name: len(unique_values)})
else:
non_numeric_colums.append(column_name)
unique_values = df[column_name].unique()
non_numeric_colums_value_map.update({column_name: len(unique_values)})
sorted_numeric_colums_value_map = dict(
sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
)
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
sorted_colums_value_map = dict(
sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
)
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# Analyze x-coordinate
if len(non_numeric_colums_sort_list) > 0:
x_cloumn = non_numeric_colums_sort_list[-1]
non_numeric_colums_sort_list.remove(x_cloumn)
else:
x_cloumn = number_columns[0]
numeric_colums_sort_list.remove(x_cloumn)
# Analyze y-coordinate
if len(numeric_colums_sort_list) > 0:
y_column = numeric_colums_sort_list[0]
numeric_colums_sort_list.remove(y_column)
else:
raise ValueError("Not enough numeric columns for chart")
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
def zh_font_set():
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
@command(
"response_line_chart",
"Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},")
if df.size <= 0:
raise ValueError("No Data")
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
x, y, non_num_columns, num_colmns = data_pre_classification(df)
# ## 复杂折线图实现
if len(num_colmns) > 0:
num_colmns.append(y)
df_melted = pd.melt(
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
)
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command(
"response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},")
if df.size <= 0:
raise ValueError("No Data")
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
hue = None
x, y, non_num_columns, num_colmns = data_pre_classification(df)
if len(non_num_columns) >= 1:
hue = non_num_columns[0]
if len(num_colmns) >= 1:
if hue:
if len(num_colmns) >= 2:
can_use_columns = num_colmns[:2]
else:
can_use_columns = num_colmns
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
for sub_y_column in can_use_columns:
sns.barplot(
data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
)
else:
if len(num_colmns) > 5:
can_use_columns = num_colmns[:5]
else:
can_use_columns = num_colmns
can_use_columns.append(y)
df_melted = pd.melt(
df,
id_vars=x,
value_vars=can_use_columns,
var_name="line",
value_name="Value",
)
sns.barplot(
data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
)
else:
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
# 设置 y 轴刻度格式为普通数字格式
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command(
"response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_pie_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_pie_chart:{speak},")
columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data")
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
plt.axis("equal") # 使饼图为正圆形
# plt.title(columns[0])
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img

View File

@ -0,0 +1,24 @@
from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_table(speak: str, df: DataFrame) -> str:
logger.info(f"response_table:{speak}")
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text

View File

@ -0,0 +1,39 @@
from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
"response_data_text",
"Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_data_text(speak: str, df: DataFrame) -> str:
logger.info(f"response_data_text:{speak}")
data = df.values
row_size = data.shape[0]
value_str = ""
text_info = ""
if row_size > 1:
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
elif row_size == 1:
row = data[0]
for value in row:
if value_str:
value_str = value_str + f", ** {value} **"
else:
value_str = f" ** {value} **"
text_info = f"{speak}: {value_str}"
else:
text_info = f"##### {speak}: _没有找到可用的数据_"
return text_info

View File

@ -0,0 +1,4 @@
class NotCommands(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@ -0,0 +1,158 @@
""" A module for generating custom prompt strings."""
import json
from typing import Any, Callable, Dict, List, Optional
class PluginPromptGenerator:
"""
A class for generating custom prompt strings based on constraints, commands,
resources, and performance evaluations.
"""
def __init__(self) -> None:
"""
Initialize the PromptGenerator object with empty lists of constraints,
commands, resources, and performance evaluations.
"""
self.constraints = []
self.commands = []
self.resources = []
self.performance_evaluation = []
self.goals = []
self.command_registry = None
self.name = "Bob"
self.role = "AI"
self.response_format = {
"thoughts": {
"text": "thought",
"reasoning": "reasoning",
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
},
"command": {"name": "command name", "args": {"arg name": "value"}},
}
def add_constraint(self, constraint: str) -> None:
"""
Add a constraint to the constraints list.
Args:
constraint (str): The constraint to be added.
"""
self.constraints.append(constraint)
def add_command(
self,
command_label: str,
command_name: str,
args=None,
function: Optional[Callable] = None,
) -> None:
"""
Add a command to the commands list with a label, name, and optional arguments.
Args:
command_label (str): The label of the command.
command_name (str): The name of the command.
args (dict, optional): A dictionary containing argument names and their
values. Defaults to None.
function (callable, optional): A callable function to be called when
the command is executed. Defaults to None.
"""
if args is None:
args = {}
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
command = {
"label": command_label,
"name": command_name,
"args": command_args,
"function": function,
}
self.commands.append(command)
def _generate_command_string(self, command: Dict[str, Any]) -> str:
"""
Generate a formatted string representation of a command.
Args:
command (dict): A dictionary containing command information.
Returns:
str: The formatted command string.
"""
args_string = ", ".join(
f'"{key}": "{value}"' for key, value in command["args"].items()
)
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
def add_resource(self, resource: str) -> None:
"""
Add a resource to the resources list.
Args:
resource (str): The resource to be added.
"""
self.resources.append(resource)
def add_performance_evaluation(self, evaluation: str) -> None:
"""
Add a performance evaluation item to the performance_evaluation list.
Args:
evaluation (str): The evaluation item to be added.
"""
self.performance_evaluation.append(evaluation)
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
"""
Generate a numbered list from given items based on the item_type.
Args:
items (list): A list of items to be numbered.
item_type (str, optional): The type of items in the list.
Defaults to 'list'.
Returns:
str: The formatted numbered list.
"""
if item_type == "command":
command_strings = []
if self.command_registry:
command_strings += [
str(item)
for item in self.command_registry.commands.values()
if item.enabled
]
# terminate command is added manually
command_strings += [self._generate_command_string(item) for item in items]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
else:
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
def generate_commands_string(self) -> str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
def generate_prompt_string(self) -> str:
"""
Generate a prompt string based on the constraints, commands, resources,
and performance evaluations.
Returns:
str: The generated prompt string.
"""
formatted_response_format = json.dumps(self.response_format, indent=4)
return (
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
"Commands:\n"
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
"Performance Evaluation:\n"
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
"You should only respond in JSON format as described below and ensure the"
"response can be parsed by Python json.loads \nResponse"
f" Format: \n{formatted_response_format}"
)

View File

@ -0,0 +1,10 @@
from datetime import datetime
def get_datetime() -> str:
"""Return the current date and time
Returns:
str: The current date and time
"""
return "Current date and time: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@ -0,0 +1,26 @@
import json
import time
from fastapi import (
APIRouter,
Body,
)
from typing import List
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.openapi.api_view_model import (
Result,
)
router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
@router.get("/v1/mange/agent/list", response_model=Result[str])
async def get_agent_list():
logger.info(f"get_agent_list!")
return Result.succ(None)

View File

@ -0,0 +1,136 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func
from sqlalchemy import UniqueConstraint
from pilot.base_modules.meta_data.meta_data import Base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
class MyPluginEntity(Base):
__tablename__ = 'my_plugin'
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String, nullable=True, comment="user's tenant")
user_code = Column(String, nullable=True, comment="user code")
user_name = Column(String, nullable=True, comment="user name")
name = Column(String, unique=True, nullable=False, comment="plugin name")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
__table_args__ = (
UniqueConstraint('name', name="uk_name"),
)
class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self):
super().__init__(
database="dbgpt", orm_base=Base, db_engine =engine , session= session
)
def add(self, engity: MyPluginEntity):
session = self.Session()
my_plugin = MyPluginEntity(
tenant=engity.tenant,
user_code=engity.user_code,
user_name=engity.user_name,
name=engity.name,
type=engity.type,
version=engity.version,
use_count=engity.use_count or 0,
succ_count=engity.succ_count or 0,
created_at=datetime.now(),
)
session.add(my_plugin)
session.commit()
id = my_plugin.id
session.close()
return id
def update(self, entity: MyPluginEntity):
session = self.Session()
updated = session.merge(entity)
session.commit()
return updated.id
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
session = self.Session()
my_plugins = session.query(MyPluginEntity)
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.name == query.name
)
if query.tenant is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.tenant == query.tenant
)
if query.type is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.type == query.type
)
if query.user_code is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_code == query.user_code
)
if query.user_name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_name == query.user_name
)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
result = my_plugins.all()
session.close()
return result
def count(self, query: MyPluginEntity):
session = self.Session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.name == query.name
)
if query.type is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.type == query.type
)
if query.tenant is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.tenant == query.tenant
)
if query.user_code is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_code == query.user_code
)
if query.user_name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_name == query.user_name
)
count = my_plugins.scalar()
session.close()
return count
def delete(self, plugin_id: int):
session = self.Session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity)
if query.id is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.id == query.id
)
my_plugins.delete()
session.commit()
session.close()

View File

@ -0,0 +1,136 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean
from sqlalchemy import UniqueConstraint
from pilot.base_modules.meta_data.meta_data import Base
from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
class PluginHubEntity(Base):
__tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
name = Column(String, unique=True, nullable=False, comment="plugin name")
author = Column(String, nullable=True, comment="plugin author")
email = Column(String, nullable=True, comment="plugin author email")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
storage_channel = Column(String, comment="plugin storage channel")
storage_url = Column(String, comment="plugin download url")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
installed = Column(Boolean, default=False, comment="plugin already installed")
__table_args__ = (
UniqueConstraint('name', name="uk_name"),
Index('idx_q_type', 'type'),
)
class PluginHubDao(BaseDao[PluginHubEntity]):
def __init__(self):
super().__init__(
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def add(self, engity: PluginHubEntity):
session = self.Session()
plugin_hub = PluginHubEntity(
name=engity.name,
author=engity.author,
email=engity.email,
type=engity.type,
version=engity.version,
storage_channel=engity.storage_channel,
storage_url=engity.storage_url,
created_at=datetime.now(),
)
session.add(plugin_hub)
session.commit()
id = plugin_hub.id
session.close()
return id
def update(self, entity: PluginHubEntity):
session = self.Session()
updated = session.merge(entity)
session.commit()
return updated.id
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
session = self.Session()
plugin_hubs = session.query(PluginHubEntity)
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.name == query.name
)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.type == query.type
)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.author == query.author
)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all()
session.close()
return result
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.Session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.name == name
)
result = plugin_hubs.get(1)
session.close()
return result
def count(self, query: PluginHubEntity):
session = self.Session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.name == query.name
)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.type == query.type
)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.author == query.author
)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
)
count = plugin_hubs.scalar()
session.close()
return count
def delete(self, plugin_id: int):
session = self.Session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = PluginHubEntity(id=plugin_id)
plugin_hubs = session.query(PluginHubEntity)
if query.id is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.id == query.id
)
plugin_hubs.delete()
session.commit()
session.close()

View File

View File

@ -0,0 +1,69 @@
import logging
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from .schema import PluginStorageType
logger = logging.getLogger("agent_hub")
class AgentHub:
def __init__(self) -> None:
self.hub_dao = PluginHubDao()
self.my_lugin_dao = MyPluginDao()
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
self.__download_from_git(plugin_name, plugin_entity.storage_url)
self.load_plugin(plugin_name)
# add to my plugins and edit hub status
plugin_entity.installed = True
my_plugin_entity = self.__build_my_plugin(plugin_entity)
if not user_name:
# TODO use user
my_plugin_entity.user_code = ""
my_plugin_entity.user_name = user_name
my_plugin_entity.tenant = ""
with self.hub_dao.Session() as session:
try:
session.add(my_plugin_entity)
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
my_plugin_entity.type = hub_plugin.type
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def __download_from_git(self, plugin_name, url):
pass
def load_plugin(self, plugin_name):
pass
def get_my_plugin(self, user: str):
pass
def uninstall_plugin(self):
pass

View File

@ -0,0 +1,6 @@
from enum import Enum
class PluginStorageType(Enum):
Git = "git"
Oss = "oss"

View File

@ -0,0 +1,182 @@
"""加载组件"""
import json
import os
import glob
import zipfile
import requests
import threading
import datetime
from pathlib import Path
from typing import List
from urllib.parse import urlparse
from zipimport import zipimporter
import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
"""
Loader zip plugin file. Native support Auto_gpt_plugin
Args:
zip_path (str): Path to the zipfile.
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
list[str]: The list of module names found or empty list if none were found.
"""
result = []
with zipfile.ZipFile(zip_path, "r") as zfile:
for name in zfile.namelist():
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
result.append(name)
if len(result) == 0:
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
return result
def write_dict_to_json_file(data: dict, file_path: str) -> None:
"""
Write a dictionary to a JSON file.
Args:
data (dict): Dictionary to write.
file_path (str): Path to the file.
"""
with open(file_path, "w") as file:
json.dump(data, file, indent=4)
def create_directory_if_not_exists(directory_path: str) -> bool:
"""
Create a directory if it does not exist.
Args:
directory_path (str): Path to the directory.
Returns:
bool: True if the directory was created, else False.
"""
if not os.path.exists(directory_path):
try:
os.makedirs(directory_path)
logger.debug(f"Created directory: {directory_path}")
return True
except OSError as e:
logger.warn(f"Error creating directory {directory_path}: {e}")
return False
else:
logger.info(f"Directory {directory_path} already exists")
return True
def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load:
print("not auto load_native_plugins")
return
def load_from_git(cfg: Config):
print("async load_native_plugins")
branch_name = cfg.plugins_git_branch
native_plugin_repo = "DB-GPT-Plugins"
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try:
session = requests.Session()
response = session.get(
url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
)
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
print("save file")
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
else:
print("get file faildresponse code", response.status_code)
except Exception as e:
print("load plugin from git exception!" + str(e))
t = threading.Thread(target=load_from_git, args=(cfg,))
t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
cfg (Config): Config instance including plugins config
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
loaded_plugins = []
current_dir = os.getcwd()
print(current_dir)
# Generic plugins
plugins_path_path = Path(PLUGINS_DIR)
for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug):
for module in moduleList:
plugin = Path(plugin)
module = Path(module)
logger.debug(f"Plugin: {plugin} Module: {module}")
zipped_package = zipimporter(str(plugin))
zipped_module = zipped_package.load_module(str(module.parent))
for key in dir(zipped_module):
if key.startswith("__"):
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
# and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
for plugin in loaded_plugins:
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
return loaded_plugins
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
"""Check if the plugin is in the allowlist or denylist.
Args:
plugin_name (str): Name of the plugin.
cfg (Config): Config object.
Returns:
True or False
"""
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
if plugin_name in cfg.plugins_denylist:
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
return False
if plugin_name in cfg.plugins_allowlist:
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
return True
ack = input(
f"WARNING: Plugin {plugin_name} found. But not in the"
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
)
return ack.lower() == cfg.authorise_key

View File

@ -0,0 +1,2 @@
flask_sqlalchemy==3.0.5
flask==2.3.2

View File

View File

@ -0,0 +1,7 @@
class ModuleMangeApi:
def module_name(self):
pass
def register(self):
pass

View File

View File

@ -0,0 +1,21 @@
from typing import TypeVar, Generic, List, Any
from sqlalchemy.orm import sessionmaker
T = TypeVar('T')
class BaseDao(Generic[T]):
def __init__(
self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None,
) -> None:
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
self._orm_base = orm_base
self._database = database
self._db_engine = db_engine
self._session = session
@property
def Session(self):
if not self._session:
self._session = sessionmaker(bind=self.db_engine)
return self._session

View File

@ -0,0 +1,109 @@
import uuid
import os
import duckdb
import sqlite3
from datetime import datetime
from typing import Optional, Type, TypeVar
import sqlalchemy as sa
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate,upgrade
from flask.cli import with_appcontext
import subprocess
from sqlalchemy import create_engine,DateTime, String, func, MetaData
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from alembic import context, command
from alembic.config import Config
default_db_path = os.path.join(os.getcwd(), "meta_data")
os.makedirs(default_db_path, exist_ok=True)
db_path = default_db_path + "/dbgpt.db"
connection = sqlite3.connect(db_path)
engine = create_engine(f'sqlite:///{db_path}')
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
Base = declarative_base(bind=engine)
# Base.metadata.create_all()
# 创建Alembic配置对象
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
os.makedirs(default_db_path + "/alembic", exist_ok=True)
alembic_cfg.set_main_option('script_location', default_db_path + "/alembic")
# 将模型和会话传递给Alembic配置
alembic_cfg.attributes['target_metadata'] = Base.metadata
alembic_cfg.attributes['session'] = session
# # 创建表
# Base.metadata.create_all(engine)
#
# # 删除表
# Base.metadata.drop_all(engine)
# app = Flask(__name__)
# default_db_path = os.path.join(os.getcwd(), "meta_data")
# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db")
# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}'
# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# db = SQLAlchemy(app)
# migrate = Migrate(app, db)
#
# # 设置FLASK_APP环境变量
# import os
# os.environ['FLASK_APP'] = 'server.dbgpt_server.py'
#
# @app.cli.command("db_init")
# @with_appcontext
# def db_init():
# subprocess.run(["flask", "db", "init"])
#
# @app.cli.command("db_migrate")
# @with_appcontext
# def db_migrate():
# subprocess.run(["flask", "db", "migrate"])
#
# @app.cli.command("db_upgrade")
# @with_appcontext
# def db_upgrade():
# subprocess.run(["flask", "db", "upgrade"])
#
def ddl_init_and_upgrade():
# Base.metadata.create_all(bind=engine)
# 生成并应用迁移脚本
# command.upgrade(alembic_cfg, 'head')
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
with engine.connect() as connection:
alembic_cfg.attributes['connection'] = connection
command.revision(alembic_cfg, "test", True)
command.upgrade(alembic_cfg, "head")
# alembic_cfg.attributes['connection'] = engine.connect()
# command.upgrade(alembic_cfg, 'head')
# with app.app_context():
# db_init()
# db_migrate()
# db_upgrade()

View File

@ -0,0 +1 @@
alembic==1.12.0

View File