mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-03 01:12:15 +00:00
fix(ChatExcel): ChatExcel OutParse Bug Fix
1.ChatExcel OutParse Bug Fix
This commit is contained in:
parent
cad2785d94
commit
af68e9c4c0
2
.gitignore
vendored
2
.gitignore
vendored
@ -150,3 +150,5 @@ pilot/mock_datas/db-gpt-test.db.wal
|
|||||||
logswebserver.log.*
|
logswebserver.log.*
|
||||||
.history/*
|
.history/*
|
||||||
.plugin_env
|
.plugin_env
|
||||||
|
/pilot/meta_data/alembic/versions/*
|
||||||
|
/pilot/meta_data/*.db
|
@ -1,2 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
@ -1,12 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
|
|
||||||
|
|
||||||
class Agent:
|
|
||||||
"""Agent class for interacting with DB-GPT
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
pass
|
|
@ -1,83 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding:utf-8 -*-
|
|
||||||
"""Agent manager for managing GPT agents"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
|
||||||
from pilot.model.base import Message
|
|
||||||
from pilot.singleton import Singleton
|
|
||||||
|
|
||||||
|
|
||||||
class AgentManager(metaclass=Singleton):
|
|
||||||
"""Agent manager for managing GPT agents"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.next_key = 0
|
|
||||||
self.agents = {} # key, (task, full_message_history, model)
|
|
||||||
self.cfg = Config()
|
|
||||||
|
|
||||||
"""Agent manager for managing DB-GPT agents
|
|
||||||
In order to compatible auto gpt plugins,
|
|
||||||
we use the same template with it.
|
|
||||||
|
|
||||||
Args: next_keys
|
|
||||||
agents
|
|
||||||
cfg
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.next_key = 0
|
|
||||||
self.agents = {} # TODO need to define
|
|
||||||
self.cfg = Config()
|
|
||||||
|
|
||||||
# Create new GPT agent
|
|
||||||
# TODO: Centralise use of create_chat_completion() to globally enforce token limit
|
|
||||||
|
|
||||||
def create_agent(self, task: str, prompt: str, model: str) -> tuple[int, str]:
|
|
||||||
"""Create a new agent and return its key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The task to perform
|
|
||||||
prompt: The prompt to use
|
|
||||||
model: The model to use
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The key of the new agent
|
|
||||||
"""
|
|
||||||
|
|
||||||
def message_agent(self, key: str | int, message: str) -> str:
|
|
||||||
"""Send a message to an agent and return its response
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The key of the agent to message
|
|
||||||
message: The message to send to the agent
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The agent's response
|
|
||||||
"""
|
|
||||||
|
|
||||||
def list_agents(self) -> list[tuple[str | int, str]]:
|
|
||||||
"""Return a list of all agents
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of tuples of the form (key, task)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Return a list of agent keys and their tasks
|
|
||||||
return [(key, task) for key, (task, _, _) in self.agents.items()]
|
|
||||||
|
|
||||||
def delete_agent(self, key: str | int) -> bool:
|
|
||||||
"""Delete an agent from the agent manager
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The key of the agent to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
del self.agents[int(key)]
|
|
||||||
return True
|
|
||||||
except KeyError:
|
|
||||||
return False
|
|
@ -1,122 +0,0 @@
|
|||||||
import contextlib
|
|
||||||
import json
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from colorama import Fore
|
|
||||||
from regex import regex
|
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
|
||||||
from pilot.json_utils.json_fix_general import (
|
|
||||||
add_quotes_to_property_names,
|
|
||||||
balance_braces,
|
|
||||||
fix_invalid_escape,
|
|
||||||
)
|
|
||||||
from pilot.logs import logger
|
|
||||||
|
|
||||||
|
|
||||||
CFG = Config()
|
|
||||||
|
|
||||||
|
|
||||||
def fix_and_parse_json(
|
|
||||||
json_to_load: str, try_to_fix_with_gpt: bool = True
|
|
||||||
) -> Dict[Any, Any]:
|
|
||||||
"""Fix and parse JSON string
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_to_load (str): The JSON string.
|
|
||||||
try_to_fix_with_gpt (bool, optional): Try to fix the JSON with GPT.
|
|
||||||
Defaults to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str or dict[Any, Any]: The parsed JSON.
|
|
||||||
"""
|
|
||||||
|
|
||||||
with contextlib.suppress(json.JSONDecodeError):
|
|
||||||
json_to_load = json_to_load.replace("\t", "")
|
|
||||||
return json.loads(json_to_load)
|
|
||||||
|
|
||||||
with contextlib.suppress(json.JSONDecodeError):
|
|
||||||
json_to_load = correct_json(json_to_load)
|
|
||||||
return json.loads(json_to_load)
|
|
||||||
# Let's do something manually:
|
|
||||||
# sometimes GPT responds with something BEFORE the braces:
|
|
||||||
# "I'm sorry, I don't understand. Please try again."
|
|
||||||
# {"text": "I'm sorry, I don't understand. Please try again.",
|
|
||||||
# "confidence": 0.0}
|
|
||||||
# So let's try to find the first brace and then parse the rest
|
|
||||||
# of the string
|
|
||||||
try:
|
|
||||||
brace_index = json_to_load.index("{")
|
|
||||||
maybe_fixed_json = json_to_load[brace_index:]
|
|
||||||
last_brace_index = maybe_fixed_json.rindex("}")
|
|
||||||
maybe_fixed_json = maybe_fixed_json[: last_brace_index + 1]
|
|
||||||
return json.loads(maybe_fixed_json)
|
|
||||||
except (json.JSONDecodeError, ValueError) as e:
|
|
||||||
logger.error("参数解析错误", e)
|
|
||||||
|
|
||||||
|
|
||||||
def correct_json(json_to_load: str) -> str:
|
|
||||||
"""
|
|
||||||
Correct common JSON errors.
|
|
||||||
Args:
|
|
||||||
json_to_load (str): The JSON string.
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.debug("json", json_to_load)
|
|
||||||
json.loads(json_to_load)
|
|
||||||
return json_to_load
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.debug("json loads error", e)
|
|
||||||
error_message = str(e)
|
|
||||||
if error_message.startswith("Invalid \\escape"):
|
|
||||||
json_to_load = fix_invalid_escape(json_to_load, error_message)
|
|
||||||
if error_message.startswith(
|
|
||||||
"Expecting property name enclosed in double quotes"
|
|
||||||
):
|
|
||||||
json_to_load = add_quotes_to_property_names(json_to_load)
|
|
||||||
try:
|
|
||||||
json.loads(json_to_load)
|
|
||||||
return json_to_load
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.debug("json loads error - add quotes", e)
|
|
||||||
error_message = str(e)
|
|
||||||
if balanced_str := balance_braces(json_to_load):
|
|
||||||
return balanced_str
|
|
||||||
return json_to_load
|
|
||||||
|
|
||||||
|
|
||||||
def attempt_to_fix_json_by_finding_outermost_brackets(json_string: str):
|
|
||||||
from pilot.speech.say import say_text
|
|
||||||
|
|
||||||
if CFG.speak_mode and CFG.debug_mode:
|
|
||||||
say_text(
|
|
||||||
"I have received an invalid JSON response from the OpenAI API. "
|
|
||||||
"Trying to fix it now."
|
|
||||||
)
|
|
||||||
logger.error("Attempting to fix JSON by finding outermost brackets\n")
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
|
|
||||||
json_match = json_pattern.search(json_string)
|
|
||||||
|
|
||||||
if json_match:
|
|
||||||
# Extract the valid JSON object from the string
|
|
||||||
json_string = json_match.group(0)
|
|
||||||
logger.typewriter_log(
|
|
||||||
title="Apparently json was fixed.", title_color=Fore.GREEN
|
|
||||||
)
|
|
||||||
if CFG.speak_mode and CFG.debug_mode:
|
|
||||||
say_text("Apparently json was fixed.")
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
if CFG.debug_mode:
|
|
||||||
logger.error(f"Error: Invalid JSON: {json_string}\n")
|
|
||||||
if CFG.speak_mode:
|
|
||||||
say_text("Didn't work. I will have to ignore this response then.")
|
|
||||||
logger.error("Error: Invalid JSON, setting it to empty JSON now.\n")
|
|
||||||
json_string = {}
|
|
||||||
|
|
||||||
return fix_and_parse_json(json_string)
|
|
0
pilot/base_modules/__init__.py
Normal file
0
pilot/base_modules/__init__.py
Normal file
6
pilot/base_modules/agent/__init__.py
Normal file
6
pilot/base_modules/agent/__init__.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
|
||||||
|
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||||
|
|
||||||
|
from .commands.command import execute_command, get_command
|
||||||
|
from .commands.generator import PluginPromptGenerator
|
||||||
|
from .commands.disply_type.show_chart_gen import static_message_img_path
|
13
pilot/base_modules/agent/agent.py
Normal file
13
pilot/base_modules/agent/agent.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFacade(ABC):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
|
||||||
|
|
0
pilot/base_modules/agent/commands/__init__.py
Normal file
0
pilot/base_modules/agent/commands/__init__.py
Normal file
61
pilot/base_modules/agent/commands/built_in/audio_text.py
Normal file
61
pilot/base_modules/agent/commands/built_in/audio_text.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""Commands for converting audio to text."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
@command(
|
||||||
|
"read_audio_from_file",
|
||||||
|
"Convert Audio to text",
|
||||||
|
'"filename": "<filename>"',
|
||||||
|
CFG.huggingface_audio_to_text_model,
|
||||||
|
"Configure huggingface_audio_to_text_model.",
|
||||||
|
)
|
||||||
|
def read_audio_from_file(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Convert audio to text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The path to the audio file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The text from the audio
|
||||||
|
"""
|
||||||
|
with open(filename, "rb") as audio_file:
|
||||||
|
audio = audio_file.read()
|
||||||
|
return read_audio(audio)
|
||||||
|
|
||||||
|
|
||||||
|
def read_audio(audio: bytes) -> str:
|
||||||
|
"""
|
||||||
|
Convert audio to text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio (bytes): The audio to convert
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The text from the audio
|
||||||
|
"""
|
||||||
|
model = CFG.huggingface_audio_to_text_model
|
||||||
|
api_url = f"https://api-inference.huggingface.co/models/{model}"
|
||||||
|
api_token = CFG.huggingface_api_token
|
||||||
|
headers = {"Authorization": f"Bearer {api_token}"}
|
||||||
|
|
||||||
|
if api_token is None:
|
||||||
|
raise ValueError(
|
||||||
|
"You need to set your Hugging Face API token in the config file."
|
||||||
|
)
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
api_url,
|
||||||
|
headers=headers,
|
||||||
|
data=audio,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = json.loads(response.content.decode("utf-8"))["text"]
|
||||||
|
return f"The audio says: {text}"
|
123
pilot/base_modules/agent/commands/built_in/image_gen.py
Normal file
123
pilot/base_modules/agent/commands/built_in/image_gen.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
""" Image Generation Module for AutoGPT."""
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
from base64 import b64decode
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.logs import logger
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
@command("generate_image", "Generate Image", '"prompt": "<prompt>"', CFG.image_provider)
|
||||||
|
def generate_image(prompt: str, size: int = 256) -> str:
|
||||||
|
"""Generate an image from a prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt to use
|
||||||
|
size (int, optional): The size of the image. Defaults to 256. (Not supported by HuggingFace)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The filename of the image
|
||||||
|
"""
|
||||||
|
filename = f"{CFG.workspace_path}/{str(uuid.uuid4())}.jpg"
|
||||||
|
|
||||||
|
# HuggingFace
|
||||||
|
if CFG.image_provider == "huggingface":
|
||||||
|
return generate_image_with_hf(prompt, filename)
|
||||||
|
# SD WebUI
|
||||||
|
elif CFG.image_provider == "sdwebui":
|
||||||
|
return generate_image_with_sd_webui(prompt, filename, size)
|
||||||
|
return "No Image Provider Set"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_with_hf(prompt: str, filename: str) -> str:
|
||||||
|
"""Generate an image with HuggingFace's API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt to use
|
||||||
|
filename (str): The filename to save the image to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The filename of the image
|
||||||
|
"""
|
||||||
|
API_URL = (
|
||||||
|
f"https://api-inference.huggingface.co/models/{CFG.huggingface_image_model}"
|
||||||
|
)
|
||||||
|
if CFG.huggingface_api_token is None:
|
||||||
|
raise ValueError(
|
||||||
|
"You need to set your Hugging Face API token in the config file."
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {CFG.huggingface_api_token}",
|
||||||
|
"X-Use-Cache": "false",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
API_URL,
|
||||||
|
headers=headers,
|
||||||
|
json={
|
||||||
|
"inputs": prompt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
image = Image.open(io.BytesIO(response.content))
|
||||||
|
logger.info(f"Image Generated for prompt:{prompt}")
|
||||||
|
|
||||||
|
image.save(filename)
|
||||||
|
|
||||||
|
return f"Saved to disk:{filename}"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_with_sd_webui(
|
||||||
|
prompt: str,
|
||||||
|
filename: str,
|
||||||
|
size: int = 512,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
extra: dict = {},
|
||||||
|
) -> str:
|
||||||
|
"""Generate an image with Stable Diffusion webui.
|
||||||
|
Args:
|
||||||
|
prompt (str): The prompt to use
|
||||||
|
filename (str): The filename to save the image to
|
||||||
|
size (int, optional): The size of the image. Defaults to 256.
|
||||||
|
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||||
|
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||||
|
Returns:
|
||||||
|
str: The filename of the image
|
||||||
|
"""
|
||||||
|
# Create a session and set the basic auth if needed
|
||||||
|
s = requests.Session()
|
||||||
|
if CFG.sd_webui_auth:
|
||||||
|
username, password = CFG.sd_webui_auth.split(":")
|
||||||
|
s.auth = (username, password or "")
|
||||||
|
|
||||||
|
# Generate the images
|
||||||
|
response = requests.post(
|
||||||
|
f"{CFG.sd_webui_url}/sdapi/v1/txt2img",
|
||||||
|
json={
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"sampler_index": "DDIM",
|
||||||
|
"steps": 20,
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
"width": size,
|
||||||
|
"height": size,
|
||||||
|
"n_iter": 1,
|
||||||
|
**extra,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Image Generated for prompt:{prompt}")
|
||||||
|
|
||||||
|
# Save the image to disk
|
||||||
|
response = response.json()
|
||||||
|
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||||
|
image = Image.open(io.BytesIO(b64))
|
||||||
|
image.save(filename)
|
||||||
|
|
||||||
|
return f"Saved to disk:{filename}"
|
152
pilot/base_modules/agent/commands/command.py
Normal file
152
pilot/base_modules/agent/commands/command.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from .exception_not_commands import NotCommands
|
||||||
|
from .generator import PluginPromptGenerator
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
def _resolve_pathlike_command_args(command_args):
|
||||||
|
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||||
|
# todo
|
||||||
|
command_args["directory"] = ""
|
||||||
|
else:
|
||||||
|
for pathlike in ["filename", "directory", "clone_path"]:
|
||||||
|
if pathlike in command_args:
|
||||||
|
# todo
|
||||||
|
command_args[pathlike] = ""
|
||||||
|
return command_args
|
||||||
|
|
||||||
|
|
||||||
|
def execute_ai_response_json(
|
||||||
|
prompt: PluginPromptGenerator,
|
||||||
|
ai_response,
|
||||||
|
user_input: str = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command_registry:
|
||||||
|
ai_response:
|
||||||
|
prompt:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
from pilot.speech.say import say_text
|
||||||
|
|
||||||
|
cfg = Config()
|
||||||
|
|
||||||
|
command_name, arguments = get_command(ai_response)
|
||||||
|
|
||||||
|
if cfg.speak_mode:
|
||||||
|
say_text(f"I want to execute {command_name}")
|
||||||
|
|
||||||
|
arguments = _resolve_pathlike_command_args(arguments)
|
||||||
|
# Execute command
|
||||||
|
if command_name is not None and command_name.lower().startswith("error"):
|
||||||
|
result = f"Command {command_name} threw the following error: {arguments}"
|
||||||
|
elif command_name == "human_feedback":
|
||||||
|
result = f"Human feedback: {user_input}"
|
||||||
|
else:
|
||||||
|
for plugin in cfg.plugins:
|
||||||
|
if not plugin.can_handle_pre_command():
|
||||||
|
continue
|
||||||
|
command_name, arguments = plugin.pre_command(command_name, arguments)
|
||||||
|
command_result = execute_command(
|
||||||
|
command_name,
|
||||||
|
arguments,
|
||||||
|
prompt,
|
||||||
|
)
|
||||||
|
result = f"{command_result}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def execute_command(
|
||||||
|
command_name: str,
|
||||||
|
arguments,
|
||||||
|
prompt: PluginPromptGenerator,
|
||||||
|
):
|
||||||
|
"""Execute the command and return the result
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command_name (str): The name of the command to execute
|
||||||
|
arguments (dict): The arguments for the command
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The result of the command
|
||||||
|
"""
|
||||||
|
|
||||||
|
cmd = prompt.command_registry.commands.get(command_name)
|
||||||
|
|
||||||
|
# If the command is found, call it with the provided arguments
|
||||||
|
if cmd:
|
||||||
|
try:
|
||||||
|
return cmd(**arguments)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}"
|
||||||
|
# TODO: Change these to take in a file rather than pasted code, if
|
||||||
|
# non-file is given, return instructions "Input should be a python
|
||||||
|
# filepath, write your code to file and try again
|
||||||
|
else:
|
||||||
|
for command in prompt.commands:
|
||||||
|
if (
|
||||||
|
command_name == command["label"].lower()
|
||||||
|
or command_name == command["name"].lower()
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# 删除非定义参数
|
||||||
|
diff_ags = list(
|
||||||
|
set(arguments.keys()).difference(set(command["args"].keys()))
|
||||||
|
)
|
||||||
|
for arg_name in diff_ags:
|
||||||
|
del arguments[arg_name]
|
||||||
|
print(str(arguments))
|
||||||
|
return command["function"](**arguments)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {str(e)}"
|
||||||
|
raise NotCommands("非可用命令" + command_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_command(response_json: Dict):
|
||||||
|
"""Parse the response and return the command name and arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response_json (json): The response from the AI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: The command name and arguments
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||||
|
|
||||||
|
Exception: If any other error occurs
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if "command" not in response_json:
|
||||||
|
return "Error:", "Missing 'command' object in JSON"
|
||||||
|
|
||||||
|
if not isinstance(response_json, dict):
|
||||||
|
return "Error:", f"'response_json' object is not dictionary {response_json}"
|
||||||
|
|
||||||
|
command = response_json["command"]
|
||||||
|
if not isinstance(command, dict):
|
||||||
|
return "Error:", "'command' object is not a dictionary"
|
||||||
|
|
||||||
|
if "name" not in command:
|
||||||
|
return "Error:", "Missing 'name' field in 'command' object"
|
||||||
|
|
||||||
|
command_name = command["name"]
|
||||||
|
|
||||||
|
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||||
|
arguments = command.get("args", {})
|
||||||
|
|
||||||
|
return command_name, arguments
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
return "Error:", "Invalid JSON"
|
||||||
|
# All other errors, return "Error: + error message"
|
||||||
|
except Exception as e:
|
||||||
|
return "Error:", str(e)
|
156
pilot/base_modules/agent/commands/command_mange.py
Normal file
156
pilot/base_modules/agent/commands/command_mange.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
# Unique identifier for auto-gpt commands
|
||||||
|
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
|
||||||
|
|
||||||
|
|
||||||
|
class Command:
|
||||||
|
"""A class representing a command.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name (str): The name of the command.
|
||||||
|
description (str): A brief description of what the command does.
|
||||||
|
signature (str): The signature of the function that the command executes. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
method: Callable[..., Any],
|
||||||
|
signature: str = "",
|
||||||
|
enabled: bool = True,
|
||||||
|
disabled_reason: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.method = method
|
||||||
|
self.signature = signature if signature else str(inspect.signature(self.method))
|
||||||
|
self.enabled = enabled
|
||||||
|
self.disabled_reason = disabled_reason
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
if not self.enabled:
|
||||||
|
return f"Command '{self.name}' is disabled: {self.disabled_reason}"
|
||||||
|
return self.method(*args, **kwargs)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.name}: {self.description}, args: {self.signature}"
|
||||||
|
|
||||||
|
|
||||||
|
class CommandRegistry:
|
||||||
|
"""
|
||||||
|
The CommandRegistry class is a manager for a collection of Command objects.
|
||||||
|
It allows the registration, modification, and retrieval of Command objects,
|
||||||
|
as well as the scanning and loading of command plugins from a specified
|
||||||
|
directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.commands = {}
|
||||||
|
|
||||||
|
def _import_module(self, module_name: str) -> Any:
|
||||||
|
return importlib.import_module(module_name)
|
||||||
|
|
||||||
|
def _reload_module(self, module: Any) -> Any:
|
||||||
|
return importlib.reload(module)
|
||||||
|
|
||||||
|
def register(self, cmd: Command) -> None:
|
||||||
|
self.commands[cmd.name] = cmd
|
||||||
|
|
||||||
|
def unregister(self, command_name: str):
|
||||||
|
if command_name in self.commands:
|
||||||
|
del self.commands[command_name]
|
||||||
|
else:
|
||||||
|
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||||
|
|
||||||
|
def reload_commands(self) -> None:
|
||||||
|
"""Reloads all loaded command plugins."""
|
||||||
|
for cmd_name in self.commands:
|
||||||
|
cmd = self.commands[cmd_name]
|
||||||
|
module = self._import_module(cmd.__module__)
|
||||||
|
reloaded_module = self._reload_module(module)
|
||||||
|
if hasattr(reloaded_module, "register"):
|
||||||
|
reloaded_module.register(self)
|
||||||
|
|
||||||
|
def get_command(self, name: str) -> Callable[..., Any]:
|
||||||
|
return self.commands[name]
|
||||||
|
|
||||||
|
def call(self, command_name: str, **kwargs) -> Any:
|
||||||
|
if command_name not in self.commands:
|
||||||
|
raise KeyError(f"Command '{command_name}' not found in registry.")
|
||||||
|
command = self.commands[command_name]
|
||||||
|
return command(**kwargs)
|
||||||
|
|
||||||
|
def command_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns a string representation of all registered `Command` objects for use in a prompt
|
||||||
|
"""
|
||||||
|
commands_list = [
|
||||||
|
f"{idx + 1}. {str(cmd)}" for idx, cmd in enumerate(self.commands.values())
|
||||||
|
]
|
||||||
|
return "\n".join(commands_list)
|
||||||
|
|
||||||
|
def import_commands(self, module_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Imports the specified Python module containing command plugins.
|
||||||
|
|
||||||
|
This method imports the associated module and registers any functions or
|
||||||
|
classes that are decorated with the `AUTO_GPT_COMMAND_IDENTIFIER` attribute
|
||||||
|
as `Command` objects. The registered `Command` objects are then added to the
|
||||||
|
`commands` dictionary of the `CommandRegistry` object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name (str): The name of the module to import for command plugins.
|
||||||
|
"""
|
||||||
|
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
for attr_name in dir(module):
|
||||||
|
attr = getattr(module, attr_name)
|
||||||
|
# Register decorated functions
|
||||||
|
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
|
||||||
|
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||||
|
):
|
||||||
|
self.register(attr.command)
|
||||||
|
# Register command classes
|
||||||
|
elif (
|
||||||
|
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||||
|
):
|
||||||
|
cmd_instance = attr()
|
||||||
|
self.register(cmd_instance)
|
||||||
|
|
||||||
|
|
||||||
|
def command(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
signature: str = "",
|
||||||
|
enabled: bool = True,
|
||||||
|
disabled_reason: Optional[str] = None,
|
||||||
|
) -> Callable[..., Any]:
|
||||||
|
"""The command decorator is used to create Command objects from ordinary functions."""
|
||||||
|
|
||||||
|
def decorator(func: Callable[..., Any]) -> Command:
|
||||||
|
cmd = Command(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
method=func,
|
||||||
|
signature=signature,
|
||||||
|
enabled=enabled,
|
||||||
|
disabled_reason=disabled_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs) -> Any:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
wrapper.command = cmd
|
||||||
|
|
||||||
|
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
296
pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
Normal file
296
pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
import pandas as pd
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
import matplotlib
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.ticker as mtick
|
||||||
|
from matplotlib.font_manager import FontManager
|
||||||
|
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||||
|
static_message_img_path = os.path.join(os.getcwd(), "message/img")
|
||||||
|
|
||||||
|
|
||||||
|
def data_pre_classification(df: DataFrame):
|
||||||
|
## Data pre-classification
|
||||||
|
columns = df.columns.tolist()
|
||||||
|
|
||||||
|
number_columns = []
|
||||||
|
non_numeric_colums = []
|
||||||
|
|
||||||
|
# 收集数据分类小于10个的列
|
||||||
|
non_numeric_colums_value_map = {}
|
||||||
|
numeric_colums_value_map = {}
|
||||||
|
for column_name in columns:
|
||||||
|
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
|
||||||
|
number_columns.append(column_name)
|
||||||
|
unique_values = df[column_name].unique()
|
||||||
|
numeric_colums_value_map.update({column_name: len(unique_values)})
|
||||||
|
else:
|
||||||
|
non_numeric_colums.append(column_name)
|
||||||
|
unique_values = df[column_name].unique()
|
||||||
|
non_numeric_colums_value_map.update({column_name: len(unique_values)})
|
||||||
|
|
||||||
|
sorted_numeric_colums_value_map = dict(
|
||||||
|
sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
|
||||||
|
)
|
||||||
|
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
|
||||||
|
|
||||||
|
sorted_colums_value_map = dict(
|
||||||
|
sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
|
||||||
|
)
|
||||||
|
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
|
||||||
|
|
||||||
|
# Analyze x-coordinate
|
||||||
|
if len(non_numeric_colums_sort_list) > 0:
|
||||||
|
x_cloumn = non_numeric_colums_sort_list[-1]
|
||||||
|
non_numeric_colums_sort_list.remove(x_cloumn)
|
||||||
|
else:
|
||||||
|
x_cloumn = number_columns[0]
|
||||||
|
numeric_colums_sort_list.remove(x_cloumn)
|
||||||
|
|
||||||
|
# Analyze y-coordinate
|
||||||
|
if len(numeric_colums_sort_list) > 0:
|
||||||
|
y_column = numeric_colums_sort_list[0]
|
||||||
|
numeric_colums_sort_list.remove(y_column)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not enough numeric columns for chart!")
|
||||||
|
|
||||||
|
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
|
||||||
|
|
||||||
|
|
||||||
|
def zh_font_set():
|
||||||
|
font_names = [
|
||||||
|
"Heiti TC",
|
||||||
|
"Songti SC",
|
||||||
|
"STHeiti Light",
|
||||||
|
"Microsoft YaHei",
|
||||||
|
"SimSun",
|
||||||
|
"SimHei",
|
||||||
|
"KaiTi",
|
||||||
|
]
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||||
|
|
||||||
|
|
||||||
|
@command(
|
||||||
|
"response_line_chart",
|
||||||
|
"Line chart display, used to display comparative trend analysis data",
|
||||||
|
'"speak": "<speak>", "df":"<data frame>"',
|
||||||
|
)
|
||||||
|
def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||||
|
logger.info(f"response_line_chart:{speak},")
|
||||||
|
if df.size <= 0:
|
||||||
|
raise ValueError("No Data!")
|
||||||
|
|
||||||
|
# set font
|
||||||
|
# zh_font_set()
|
||||||
|
font_names = [
|
||||||
|
"Heiti TC",
|
||||||
|
"Songti SC",
|
||||||
|
"STHeiti Light",
|
||||||
|
"Microsoft YaHei",
|
||||||
|
"SimSun",
|
||||||
|
"SimHei",
|
||||||
|
"KaiTi",
|
||||||
|
]
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||||
|
|
||||||
|
rc = {"font.sans-serif": can_use_fonts}
|
||||||
|
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||||
|
|
||||||
|
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||||
|
sns.set_palette("Set3") # 设置颜色主题
|
||||||
|
sns.set_style("dark")
|
||||||
|
sns.color_palette("hls", 10)
|
||||||
|
sns.hls_palette(8, l=0.5, s=0.7)
|
||||||
|
sns.set(context="notebook", style="ticks", rc=rc)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
|
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||||
|
# ## 复杂折线图实现
|
||||||
|
if len(num_colmns) > 0:
|
||||||
|
num_colmns.append(y)
|
||||||
|
df_melted = pd.melt(
|
||||||
|
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
|
||||||
|
)
|
||||||
|
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
|
||||||
|
else:
|
||||||
|
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
|
||||||
|
|
||||||
|
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
|
||||||
|
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||||
|
|
||||||
|
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||||
|
chart_path = static_message_img_path + "/" + chart_name
|
||||||
|
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||||
|
|
||||||
|
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||||
|
return html_img
|
||||||
|
|
||||||
|
|
||||||
|
@command(
|
||||||
|
"response_bar_chart",
|
||||||
|
"Histogram, suitable for comparative analysis of multiple target values",
|
||||||
|
'"speak": "<speak>", "df":"<data frame>"',
|
||||||
|
)
|
||||||
|
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||||
|
logger.info(f"response_bar_chart:{speak},")
|
||||||
|
if df.size <= 0:
|
||||||
|
raise ValueError("No Data!")
|
||||||
|
|
||||||
|
# set font
|
||||||
|
# zh_font_set()
|
||||||
|
font_names = [
|
||||||
|
"Heiti TC",
|
||||||
|
"Songti SC",
|
||||||
|
"STHeiti Light",
|
||||||
|
"Microsoft YaHei",
|
||||||
|
"SimSun",
|
||||||
|
"SimHei",
|
||||||
|
"KaiTi",
|
||||||
|
]
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||||
|
|
||||||
|
rc = {"font.sans-serif": can_use_fonts}
|
||||||
|
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||||
|
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||||
|
sns.set_palette("Set3") # 设置颜色主题
|
||||||
|
sns.set_style("dark")
|
||||||
|
sns.color_palette("hls", 10)
|
||||||
|
sns.hls_palette(8, l=0.5, s=0.7)
|
||||||
|
sns.set(context="notebook", style="ticks", rc=rc)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
|
|
||||||
|
hue = None
|
||||||
|
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||||
|
if len(non_num_columns) >= 1:
|
||||||
|
hue = non_num_columns[0]
|
||||||
|
|
||||||
|
if len(num_colmns) >= 1:
|
||||||
|
if hue:
|
||||||
|
if len(num_colmns) >= 2:
|
||||||
|
can_use_columns = num_colmns[:2]
|
||||||
|
else:
|
||||||
|
can_use_columns = num_colmns
|
||||||
|
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||||
|
for sub_y_column in can_use_columns:
|
||||||
|
sns.barplot(
|
||||||
|
data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if len(num_colmns) > 5:
|
||||||
|
can_use_columns = num_colmns[:5]
|
||||||
|
else:
|
||||||
|
can_use_columns = num_colmns
|
||||||
|
can_use_columns.append(y)
|
||||||
|
|
||||||
|
df_melted = pd.melt(
|
||||||
|
df,
|
||||||
|
id_vars=x,
|
||||||
|
value_vars=can_use_columns,
|
||||||
|
var_name="line",
|
||||||
|
value_name="Value",
|
||||||
|
)
|
||||||
|
sns.barplot(
|
||||||
|
data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||||
|
|
||||||
|
# 设置 y 轴刻度格式为普通数字格式
|
||||||
|
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
|
||||||
|
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||||
|
|
||||||
|
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
|
||||||
|
chart_path = static_message_img_path + "/" + chart_name
|
||||||
|
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||||
|
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||||
|
return html_img
|
||||||
|
|
||||||
|
|
||||||
|
@command(
|
||||||
|
"response_pie_chart",
|
||||||
|
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||||
|
'"speak": "<speak>", "df":"<data frame>"',
|
||||||
|
)
|
||||||
|
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||||
|
logger.info(f"response_pie_chart:{speak},")
|
||||||
|
columns = df.columns.tolist()
|
||||||
|
if df.size <= 0:
|
||||||
|
raise ValueError("No Data!")
|
||||||
|
# set font
|
||||||
|
# zh_font_set()
|
||||||
|
font_names = [
|
||||||
|
"Heiti TC",
|
||||||
|
"Songti SC",
|
||||||
|
"STHeiti Light",
|
||||||
|
"Microsoft YaHei",
|
||||||
|
"SimSun",
|
||||||
|
"SimHei",
|
||||||
|
"KaiTi",
|
||||||
|
]
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||||
|
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||||
|
|
||||||
|
sns.set_palette("Set3") # 设置颜色主题
|
||||||
|
|
||||||
|
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
|
ax = df.plot(
|
||||||
|
kind="pie",
|
||||||
|
y=columns[1],
|
||||||
|
ax=ax,
|
||||||
|
labels=df[columns[0]].values,
|
||||||
|
startangle=90,
|
||||||
|
autopct="%1.1f%%",
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.axis("equal") # 使饼图为正圆形
|
||||||
|
# plt.title(columns[0])
|
||||||
|
|
||||||
|
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||||
|
chart_path = static_message_img_path + "/" + chart_name
|
||||||
|
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||||
|
|
||||||
|
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||||
|
|
||||||
|
return html_img
|
@ -0,0 +1,24 @@
|
|||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||||
|
|
||||||
|
|
||||||
|
@command(
|
||||||
|
"response_table",
|
||||||
|
"Table display, suitable for display with many display columns or non-numeric columns",
|
||||||
|
'"speak": "<speak>", "df":"<data frame>"',
|
||||||
|
)
|
||||||
|
def response_table(speak: str, df: DataFrame) -> str:
|
||||||
|
logger.info(f"response_table:{speak}")
|
||||||
|
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||||
|
table_str = "".join(html_table.split())
|
||||||
|
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||||
|
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||||
|
return view_text
|
@ -0,0 +1,39 @@
|
|||||||
|
from pandas import DataFrame
|
||||||
|
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import command
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||||
|
|
||||||
|
|
||||||
|
@command(
|
||||||
|
"response_data_text",
|
||||||
|
"Text display, the default display method, suitable for single-line or simple content display",
|
||||||
|
'"speak": "<speak>", "df":"<data frame>"',
|
||||||
|
)
|
||||||
|
def response_data_text(speak: str, df: DataFrame) -> str:
|
||||||
|
logger.info(f"response_data_text:{speak}")
|
||||||
|
data = df.values
|
||||||
|
|
||||||
|
row_size = data.shape[0]
|
||||||
|
value_str = ""
|
||||||
|
text_info = ""
|
||||||
|
if row_size > 1:
|
||||||
|
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||||
|
table_str = "".join(html_table.split())
|
||||||
|
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||||
|
text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||||
|
elif row_size == 1:
|
||||||
|
row = data[0]
|
||||||
|
for value in row:
|
||||||
|
if value_str:
|
||||||
|
value_str = value_str + f", ** {value} **"
|
||||||
|
else:
|
||||||
|
value_str = f" ** {value} **"
|
||||||
|
text_info = f"{speak}: {value_str}"
|
||||||
|
else:
|
||||||
|
text_info = f"##### {speak}: _没有找到可用的数据_"
|
||||||
|
return text_info
|
@ -0,0 +1,4 @@
|
|||||||
|
class NotCommands(Exception):
|
||||||
|
def __init__(self, message):
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
158
pilot/base_modules/agent/commands/generator.py
Normal file
158
pilot/base_modules/agent/commands/generator.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
""" A module for generating custom prompt strings."""
|
||||||
|
import json
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class PluginPromptGenerator:
|
||||||
|
"""
|
||||||
|
A class for generating custom prompt strings based on constraints, commands,
|
||||||
|
resources, and performance evaluations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the PromptGenerator object with empty lists of constraints,
|
||||||
|
commands, resources, and performance evaluations.
|
||||||
|
"""
|
||||||
|
self.constraints = []
|
||||||
|
self.commands = []
|
||||||
|
self.resources = []
|
||||||
|
self.performance_evaluation = []
|
||||||
|
self.goals = []
|
||||||
|
self.command_registry = None
|
||||||
|
self.name = "Bob"
|
||||||
|
self.role = "AI"
|
||||||
|
self.response_format = {
|
||||||
|
"thoughts": {
|
||||||
|
"text": "thought",
|
||||||
|
"reasoning": "reasoning",
|
||||||
|
"plan": "- short bulleted\n- list that conveys\n- long-term plan",
|
||||||
|
"criticism": "constructive self-criticism",
|
||||||
|
"speak": "thoughts summary to say to user",
|
||||||
|
},
|
||||||
|
"command": {"name": "command name", "args": {"arg name": "value"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_constraint(self, constraint: str) -> None:
|
||||||
|
"""
|
||||||
|
Add a constraint to the constraints list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
constraint (str): The constraint to be added.
|
||||||
|
"""
|
||||||
|
self.constraints.append(constraint)
|
||||||
|
|
||||||
|
def add_command(
|
||||||
|
self,
|
||||||
|
command_label: str,
|
||||||
|
command_name: str,
|
||||||
|
args=None,
|
||||||
|
function: Optional[Callable] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add a command to the commands list with a label, name, and optional arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command_label (str): The label of the command.
|
||||||
|
command_name (str): The name of the command.
|
||||||
|
args (dict, optional): A dictionary containing argument names and their
|
||||||
|
values. Defaults to None.
|
||||||
|
function (callable, optional): A callable function to be called when
|
||||||
|
the command is executed. Defaults to None.
|
||||||
|
"""
|
||||||
|
if args is None:
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
command_args = {arg_key: arg_value for arg_key, arg_value in args.items()}
|
||||||
|
|
||||||
|
command = {
|
||||||
|
"label": command_label,
|
||||||
|
"name": command_name,
|
||||||
|
"args": command_args,
|
||||||
|
"function": function,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.commands.append(command)
|
||||||
|
|
||||||
|
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Generate a formatted string representation of a command.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command (dict): A dictionary containing command information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The formatted command string.
|
||||||
|
"""
|
||||||
|
args_string = ", ".join(
|
||||||
|
f'"{key}": "{value}"' for key, value in command["args"].items()
|
||||||
|
)
|
||||||
|
return f'{command["label"]}: "{command["name"]}", args: {args_string}'
|
||||||
|
|
||||||
|
def add_resource(self, resource: str) -> None:
|
||||||
|
"""
|
||||||
|
Add a resource to the resources list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource (str): The resource to be added.
|
||||||
|
"""
|
||||||
|
self.resources.append(resource)
|
||||||
|
|
||||||
|
def add_performance_evaluation(self, evaluation: str) -> None:
|
||||||
|
"""
|
||||||
|
Add a performance evaluation item to the performance_evaluation list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
evaluation (str): The evaluation item to be added.
|
||||||
|
"""
|
||||||
|
self.performance_evaluation.append(evaluation)
|
||||||
|
|
||||||
|
def _generate_numbered_list(self, items: List[Any], item_type="list") -> str:
|
||||||
|
"""
|
||||||
|
Generate a numbered list from given items based on the item_type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items (list): A list of items to be numbered.
|
||||||
|
item_type (str, optional): The type of items in the list.
|
||||||
|
Defaults to 'list'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The formatted numbered list.
|
||||||
|
"""
|
||||||
|
if item_type == "command":
|
||||||
|
command_strings = []
|
||||||
|
if self.command_registry:
|
||||||
|
command_strings += [
|
||||||
|
str(item)
|
||||||
|
for item in self.command_registry.commands.values()
|
||||||
|
if item.enabled
|
||||||
|
]
|
||||||
|
# terminate command is added manually
|
||||||
|
command_strings += [self._generate_command_string(item) for item in items]
|
||||||
|
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||||
|
else:
|
||||||
|
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(items))
|
||||||
|
|
||||||
|
def generate_commands_string(self) -> str:
|
||||||
|
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
|
||||||
|
|
||||||
|
def generate_prompt_string(self) -> str:
|
||||||
|
"""
|
||||||
|
Generate a prompt string based on the constraints, commands, resources,
|
||||||
|
and performance evaluations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The generated prompt string.
|
||||||
|
"""
|
||||||
|
formatted_response_format = json.dumps(self.response_format, indent=4)
|
||||||
|
return (
|
||||||
|
f"Constraints:\n{self._generate_numbered_list(self.constraints)}\n\n"
|
||||||
|
"Commands:\n"
|
||||||
|
f"{self._generate_numbered_list(self.commands, item_type='command')}\n\n"
|
||||||
|
f"Resources:\n{self._generate_numbered_list(self.resources)}\n\n"
|
||||||
|
"Performance Evaluation:\n"
|
||||||
|
f"{self._generate_numbered_list(self.performance_evaluation)}\n\n"
|
||||||
|
"You should only respond in JSON format as described below and ensure the"
|
||||||
|
"response can be parsed by Python json.loads \nResponse"
|
||||||
|
f" Format: \n{formatted_response_format}"
|
||||||
|
)
|
10
pilot/base_modules/agent/commands/times.py
Normal file
10
pilot/base_modules/agent/commands/times.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def get_datetime() -> str:
|
||||||
|
"""Return the current date and time
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The current date and time
|
||||||
|
"""
|
||||||
|
return "Current date and time: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
26
pilot/base_modules/agent/controller.py
Normal file
26
pilot/base_modules/agent/controller.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Body,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
|
from pilot.openapi.api_view_model import (
|
||||||
|
Result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/mange/agent/list", response_model=Result[str])
|
||||||
|
async def get_agent_list():
|
||||||
|
logger.info(f"get_agent_list!")
|
||||||
|
|
||||||
|
return Result.succ(None)
|
136
pilot/base_modules/agent/db/my_plugin_db.py
Normal file
136
pilot/base_modules/agent/db/my_plugin_db.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
from sqlalchemy import Column, Integer, String, Index, DateTime, func
|
||||||
|
from sqlalchemy import UniqueConstraint
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
|
|
||||||
|
|
||||||
|
class MyPluginEntity(Base):
|
||||||
|
__tablename__ = 'my_plugin'
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, comment="autoincrement id")
|
||||||
|
tenant = Column(String, nullable=True, comment="user's tenant")
|
||||||
|
user_code = Column(String, nullable=True, comment="user code")
|
||||||
|
user_name = Column(String, nullable=True, comment="user name")
|
||||||
|
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
||||||
|
type = Column(String, comment="plugin type")
|
||||||
|
version = Column(String, comment="plugin version")
|
||||||
|
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
|
||||||
|
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint('name', name="uk_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
database="dbgpt", orm_base=Base, db_engine =engine , session= session
|
||||||
|
)
|
||||||
|
|
||||||
|
def add(self, engity: MyPluginEntity):
|
||||||
|
session = self.Session()
|
||||||
|
my_plugin = MyPluginEntity(
|
||||||
|
tenant=engity.tenant,
|
||||||
|
user_code=engity.user_code,
|
||||||
|
user_name=engity.user_name,
|
||||||
|
name=engity.name,
|
||||||
|
type=engity.type,
|
||||||
|
version=engity.version,
|
||||||
|
use_count=engity.use_count or 0,
|
||||||
|
succ_count=engity.succ_count or 0,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
session.add(my_plugin)
|
||||||
|
session.commit()
|
||||||
|
id = my_plugin.id
|
||||||
|
session.close()
|
||||||
|
return id
|
||||||
|
|
||||||
|
def update(self, entity: MyPluginEntity):
|
||||||
|
session = self.Session()
|
||||||
|
updated = session.merge(entity)
|
||||||
|
session.commit()
|
||||||
|
return updated.id
|
||||||
|
|
||||||
|
|
||||||
|
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
|
||||||
|
session = self.Session()
|
||||||
|
my_plugins = session.query(MyPluginEntity)
|
||||||
|
if query.id is not None:
|
||||||
|
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||||
|
if query.name is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.name == query.name
|
||||||
|
)
|
||||||
|
if query.tenant is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.tenant == query.tenant
|
||||||
|
)
|
||||||
|
if query.type is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.type == query.type
|
||||||
|
)
|
||||||
|
if query.user_code is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.user_code == query.user_code
|
||||||
|
)
|
||||||
|
if query.user_name is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.user_name == query.user_name
|
||||||
|
)
|
||||||
|
|
||||||
|
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
|
||||||
|
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
|
||||||
|
result = my_plugins.all()
|
||||||
|
session.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def count(self, query: MyPluginEntity):
|
||||||
|
session = self.Session()
|
||||||
|
my_plugins = session.query(func.count(MyPluginEntity.id))
|
||||||
|
if query.id is not None:
|
||||||
|
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||||
|
if query.name is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.name == query.name
|
||||||
|
)
|
||||||
|
if query.type is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.type == query.type
|
||||||
|
)
|
||||||
|
if query.tenant is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.tenant == query.tenant
|
||||||
|
)
|
||||||
|
if query.user_code is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.user_code == query.user_code
|
||||||
|
)
|
||||||
|
if query.user_name is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.user_name == query.user_name
|
||||||
|
)
|
||||||
|
count = my_plugins.scalar()
|
||||||
|
session.close()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def delete(self, plugin_id: int):
|
||||||
|
session = self.Session()
|
||||||
|
if plugin_id is None:
|
||||||
|
raise Exception("plugin_id is None")
|
||||||
|
query = MyPluginEntity(id=plugin_id)
|
||||||
|
my_plugins = session.query(MyPluginEntity)
|
||||||
|
if query.id is not None:
|
||||||
|
my_plugins = my_plugins.filter(
|
||||||
|
MyPluginEntity.id == query.id
|
||||||
|
)
|
||||||
|
my_plugins.delete()
|
||||||
|
session.commit()
|
||||||
|
session.close()
|
136
pilot/base_modules/agent/db/plugin_hub_db.py
Normal file
136
pilot/base_modules/agent/db/plugin_hub_db.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List
|
||||||
|
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean
|
||||||
|
from sqlalchemy import UniqueConstraint
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base
|
||||||
|
|
||||||
|
from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||||
|
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||||
|
|
||||||
|
|
||||||
|
class PluginHubEntity(Base):
|
||||||
|
__tablename__ = 'plugin_hub'
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
||||||
|
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
||||||
|
author = Column(String, nullable=True, comment="plugin author")
|
||||||
|
email = Column(String, nullable=True, comment="plugin author email")
|
||||||
|
type = Column(String, comment="plugin type")
|
||||||
|
version = Column(String, comment="plugin version")
|
||||||
|
storage_channel = Column(String, comment="plugin storage channel")
|
||||||
|
storage_url = Column(String, comment="plugin download url")
|
||||||
|
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
|
||||||
|
installed = Column(Boolean, default=False, comment="plugin already installed")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint('name', name="uk_name"),
|
||||||
|
Index('idx_q_type', 'type'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||||
|
)
|
||||||
|
|
||||||
|
def add(self, engity: PluginHubEntity):
|
||||||
|
session = self.Session()
|
||||||
|
plugin_hub = PluginHubEntity(
|
||||||
|
name=engity.name,
|
||||||
|
author=engity.author,
|
||||||
|
email=engity.email,
|
||||||
|
type=engity.type,
|
||||||
|
version=engity.version,
|
||||||
|
storage_channel=engity.storage_channel,
|
||||||
|
storage_url=engity.storage_url,
|
||||||
|
created_at=datetime.now(),
|
||||||
|
)
|
||||||
|
session.add(plugin_hub)
|
||||||
|
session.commit()
|
||||||
|
id = plugin_hub.id
|
||||||
|
session.close()
|
||||||
|
return id
|
||||||
|
|
||||||
|
def update(self, entity: PluginHubEntity):
|
||||||
|
session = self.Session()
|
||||||
|
updated = session.merge(entity)
|
||||||
|
session.commit()
|
||||||
|
return updated.id
|
||||||
|
|
||||||
|
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
|
||||||
|
session = self.Session()
|
||||||
|
plugin_hubs = session.query(PluginHubEntity)
|
||||||
|
if query.id is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||||
|
if query.name is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.name == query.name
|
||||||
|
)
|
||||||
|
if query.type is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.type == query.type
|
||||||
|
)
|
||||||
|
if query.author is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.author == query.author
|
||||||
|
)
|
||||||
|
if query.storage_channel is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.storage_channel == query.storage_channel
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin_hubs = plugin_hubs.order_by(PluginHubEntity.id.desc())
|
||||||
|
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
|
||||||
|
result = plugin_hubs.all()
|
||||||
|
session.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_by_name(self, name: str) -> PluginHubEntity:
|
||||||
|
session = self.Session()
|
||||||
|
plugin_hubs = session.query(PluginHubEntity)
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.name == name
|
||||||
|
)
|
||||||
|
result = plugin_hubs.get(1)
|
||||||
|
session.close()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def count(self, query: PluginHubEntity):
|
||||||
|
session = self.Session()
|
||||||
|
plugin_hubs = session.query(func.count(PluginHubEntity.id))
|
||||||
|
if query.id is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||||
|
if query.name is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.name == query.name
|
||||||
|
)
|
||||||
|
if query.type is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.type == query.type
|
||||||
|
)
|
||||||
|
if query.author is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.author == query.author
|
||||||
|
)
|
||||||
|
if query.storage_channel is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.storage_channel == query.storage_channel
|
||||||
|
)
|
||||||
|
count = plugin_hubs.scalar()
|
||||||
|
session.close()
|
||||||
|
return count
|
||||||
|
|
||||||
|
def delete(self, plugin_id: int):
|
||||||
|
session = self.Session()
|
||||||
|
if plugin_id is None:
|
||||||
|
raise Exception("plugin_id is None")
|
||||||
|
query = PluginHubEntity(id=plugin_id)
|
||||||
|
plugin_hubs = session.query(PluginHubEntity)
|
||||||
|
if query.id is not None:
|
||||||
|
plugin_hubs = plugin_hubs.filter(
|
||||||
|
PluginHubEntity.id == query.id
|
||||||
|
)
|
||||||
|
plugin_hubs.delete()
|
||||||
|
session.commit()
|
||||||
|
session.close()
|
0
pilot/base_modules/agent/hub/__init__.py
Normal file
0
pilot/base_modules/agent/hub/__init__.py
Normal file
69
pilot/base_modules/agent/hub/agent_hub.py
Normal file
69
pilot/base_modules/agent/hub/agent_hub.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||||
|
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||||
|
from .schema import PluginStorageType
|
||||||
|
|
||||||
|
logger = logging.getLogger("agent_hub")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentHub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.hub_dao = PluginHubDao()
|
||||||
|
self.my_lugin_dao = MyPluginDao()
|
||||||
|
|
||||||
|
def install_plugin(self, plugin_name: str, user_name: str = None):
|
||||||
|
logger.info(f"install_plugin {plugin_name}")
|
||||||
|
|
||||||
|
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||||
|
if plugin_entity:
|
||||||
|
if plugin_entity.storage_channel == PluginStorageType.Git.value:
|
||||||
|
try:
|
||||||
|
self.__download_from_git(plugin_name, plugin_entity.storage_url)
|
||||||
|
self.load_plugin(plugin_name)
|
||||||
|
|
||||||
|
# add to my plugins and edit hub status
|
||||||
|
plugin_entity.installed = True
|
||||||
|
|
||||||
|
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||||
|
if not user_name:
|
||||||
|
# TODO use user
|
||||||
|
my_plugin_entity.user_code = ""
|
||||||
|
my_plugin_entity.user_name = user_name
|
||||||
|
my_plugin_entity.tenant = ""
|
||||||
|
|
||||||
|
with self.hub_dao.Session() as session:
|
||||||
|
try:
|
||||||
|
session.add(my_plugin_entity)
|
||||||
|
session.merge(plugin_entity)
|
||||||
|
session.commit()
|
||||||
|
except:
|
||||||
|
session.rollback()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("install pluguin exception!", e)
|
||||||
|
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
||||||
|
|
||||||
|
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
|
||||||
|
my_plugin_entity = MyPluginEntity()
|
||||||
|
my_plugin_entity.name = hub_plugin.name
|
||||||
|
my_plugin_entity.type = hub_plugin.type
|
||||||
|
my_plugin_entity.version = hub_plugin.version
|
||||||
|
return my_plugin_entity
|
||||||
|
|
||||||
|
def __download_from_git(self, plugin_name, url):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_plugin(self, plugin_name):
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_my_plugin(self, user: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def uninstall_plugin(self):
|
||||||
|
pass
|
6
pilot/base_modules/agent/hub/schema.py
Normal file
6
pilot/base_modules/agent/hub/schema.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class PluginStorageType(Enum):
|
||||||
|
Git = "git"
|
||||||
|
Oss = "oss"
|
182
pilot/base_modules/agent/plugins.py
Normal file
182
pilot/base_modules/agent/plugins.py
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
"""加载组件"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import zipfile
|
||||||
|
import requests
|
||||||
|
import threading
|
||||||
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from zipimport import zipimporter
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||||
|
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.configs.model_config import PLUGINS_DIR
|
||||||
|
from pilot.logs import logger
|
||||||
|
|
||||||
|
|
||||||
|
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||||
|
"""
|
||||||
|
Loader zip plugin file. Native support Auto_gpt_plugin
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zip_path (str): Path to the zipfile.
|
||||||
|
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: The list of module names found or empty list if none were found.
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zfile:
|
||||||
|
for name in zfile.namelist():
|
||||||
|
if name.endswith("__init__.py") and not name.startswith("__MACOSX"):
|
||||||
|
logger.debug(f"Found module '{name}' in the zipfile at: {name}")
|
||||||
|
result.append(name)
|
||||||
|
if len(result) == 0:
|
||||||
|
logger.debug(f"Module '__init__.py' not found in the zipfile @ {zip_path}.")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def write_dict_to_json_file(data: dict, file_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Write a dictionary to a JSON file.
|
||||||
|
Args:
|
||||||
|
data (dict): Dictionary to write.
|
||||||
|
file_path (str): Path to the file.
|
||||||
|
"""
|
||||||
|
with open(file_path, "w") as file:
|
||||||
|
json.dump(data, file, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def create_directory_if_not_exists(directory_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Create a directory if it does not exist.
|
||||||
|
Args:
|
||||||
|
directory_path (str): Path to the directory.
|
||||||
|
Returns:
|
||||||
|
bool: True if the directory was created, else False.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(directory_path):
|
||||||
|
try:
|
||||||
|
os.makedirs(directory_path)
|
||||||
|
logger.debug(f"Created directory: {directory_path}")
|
||||||
|
return True
|
||||||
|
except OSError as e:
|
||||||
|
logger.warn(f"Error creating directory {directory_path}: {e}")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.info(f"Directory {directory_path} already exists")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def load_native_plugins(cfg: Config):
|
||||||
|
if not cfg.plugins_auto_load:
|
||||||
|
print("not auto load_native_plugins")
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_from_git(cfg: Config):
|
||||||
|
print("async load_native_plugins")
|
||||||
|
branch_name = cfg.plugins_git_branch
|
||||||
|
native_plugin_repo = "DB-GPT-Plugins"
|
||||||
|
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
|
||||||
|
try:
|
||||||
|
session = requests.Session()
|
||||||
|
response = session.get(
|
||||||
|
url.format(repo=native_plugin_repo, branch=branch_name),
|
||||||
|
headers={"Authorization": "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"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
plugins_path_path = Path(PLUGINS_DIR)
|
||||||
|
files = glob.glob(
|
||||||
|
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
|
||||||
|
)
|
||||||
|
for file in files:
|
||||||
|
os.remove(file)
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||||
|
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
|
||||||
|
print(file_name)
|
||||||
|
with open(file_name, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
print("save file")
|
||||||
|
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||||
|
else:
|
||||||
|
print("get file faild,response code:", response.status_code)
|
||||||
|
except Exception as e:
|
||||||
|
print("load plugin from git exception!" + str(e))
|
||||||
|
|
||||||
|
t = threading.Thread(target=load_from_git, args=(cfg,))
|
||||||
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
|
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||||
|
"""Scan the plugins directory for plugins and loads them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (Config): Config instance including plugins config
|
||||||
|
debug (bool, optional): Enable debug logging. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[str, Path]]: List of plugins.
|
||||||
|
"""
|
||||||
|
loaded_plugins = []
|
||||||
|
current_dir = os.getcwd()
|
||||||
|
print(current_dir)
|
||||||
|
# Generic plugins
|
||||||
|
plugins_path_path = Path(PLUGINS_DIR)
|
||||||
|
|
||||||
|
for plugin in plugins_path_path.glob("*.zip"):
|
||||||
|
if moduleList := inspect_zip_for_modules(str(plugin), debug):
|
||||||
|
for module in moduleList:
|
||||||
|
plugin = Path(plugin)
|
||||||
|
module = Path(module)
|
||||||
|
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||||
|
zipped_package = zipimporter(str(plugin))
|
||||||
|
zipped_module = zipped_package.load_module(str(module.parent))
|
||||||
|
for key in dir(zipped_module):
|
||||||
|
if key.startswith("__"):
|
||||||
|
continue
|
||||||
|
a_module = getattr(zipped_module, key)
|
||||||
|
a_keys = dir(a_module)
|
||||||
|
if (
|
||||||
|
"_abc_impl" in a_keys
|
||||||
|
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||||
|
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||||
|
):
|
||||||
|
loaded_plugins.append(a_module())
|
||||||
|
|
||||||
|
if loaded_plugins:
|
||||||
|
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||||
|
for plugin in loaded_plugins:
|
||||||
|
logger.info(f"{plugin._name}: {plugin._version} - {plugin._description}")
|
||||||
|
return loaded_plugins
|
||||||
|
|
||||||
|
|
||||||
|
def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||||
|
"""Check if the plugin is in the allowlist or denylist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_name (str): Name of the plugin.
|
||||||
|
cfg (Config): Config object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True or False
|
||||||
|
"""
|
||||||
|
logger.debug(f"Checking if plugin {plugin_name} should be loaded")
|
||||||
|
if plugin_name in cfg.plugins_denylist:
|
||||||
|
logger.debug(f"Not loading plugin {plugin_name} as it was in the denylist.")
|
||||||
|
return False
|
||||||
|
if plugin_name in cfg.plugins_allowlist:
|
||||||
|
logger.debug(f"Loading plugin {plugin_name} as it was in the allowlist.")
|
||||||
|
return True
|
||||||
|
ack = input(
|
||||||
|
f"WARNING: Plugin {plugin_name} found. But not in the"
|
||||||
|
f" allowlist... Load? ({cfg.authorise_key}/{cfg.exit_key}): "
|
||||||
|
)
|
||||||
|
return ack.lower() == cfg.authorise_key
|
2
pilot/base_modules/agent/requirement.txt
Normal file
2
pilot/base_modules/agent/requirement.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
flask_sqlalchemy==3.0.5
|
||||||
|
flask==2.3.2
|
0
pilot/base_modules/base.py
Normal file
0
pilot/base_modules/base.py
Normal file
7
pilot/base_modules/mange_base_api.py
Normal file
7
pilot/base_modules/mange_base_api.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
class ModuleMangeApi:
|
||||||
|
|
||||||
|
def module_name(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def register(self):
|
||||||
|
pass
|
0
pilot/base_modules/meta_data/__init__.py
Normal file
0
pilot/base_modules/meta_data/__init__.py
Normal file
21
pilot/base_modules/meta_data/base_dao.py
Normal file
21
pilot/base_modules/meta_data/base_dao.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import TypeVar, Generic, List, Any
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
class BaseDao(Generic[T]):
|
||||||
|
def __init__(
|
||||||
|
self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None,
|
||||||
|
) -> None:
|
||||||
|
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
|
||||||
|
self._orm_base = orm_base
|
||||||
|
self._database = database
|
||||||
|
|
||||||
|
self._db_engine = db_engine
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
@property
|
||||||
|
def Session(self):
|
||||||
|
if not self._session:
|
||||||
|
self._session = sessionmaker(bind=self.db_engine)
|
||||||
|
return self._session
|
109
pilot/base_modules/meta_data/meta_data.py
Normal file
109
pilot/base_modules/meta_data/meta_data.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
import duckdb
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Type, TypeVar
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
|
from flask_migrate import Migrate,upgrade
|
||||||
|
from flask.cli import with_appcontext
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine,DateTime, String, func, MetaData
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import Mapped
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from alembic import context, command
|
||||||
|
from alembic.config import Config
|
||||||
|
|
||||||
|
default_db_path = os.path.join(os.getcwd(), "meta_data")
|
||||||
|
|
||||||
|
os.makedirs(default_db_path, exist_ok=True)
|
||||||
|
|
||||||
|
db_path = default_db_path + "/dbgpt.db"
|
||||||
|
connection = sqlite3.connect(db_path)
|
||||||
|
engine = create_engine(f'sqlite:///{db_path}')
|
||||||
|
|
||||||
|
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
session = Session()
|
||||||
|
|
||||||
|
Base = declarative_base(bind=engine)
|
||||||
|
|
||||||
|
# Base.metadata.create_all()
|
||||||
|
|
||||||
|
# 创建Alembic配置对象
|
||||||
|
|
||||||
|
alembic_ini_path = default_db_path + "/alembic.ini"
|
||||||
|
alembic_cfg = Config(alembic_ini_path)
|
||||||
|
|
||||||
|
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
|
||||||
|
|
||||||
|
os.makedirs(default_db_path + "/alembic", exist_ok=True)
|
||||||
|
alembic_cfg.set_main_option('script_location', default_db_path + "/alembic")
|
||||||
|
|
||||||
|
# 将模型和会话传递给Alembic配置
|
||||||
|
alembic_cfg.attributes['target_metadata'] = Base.metadata
|
||||||
|
alembic_cfg.attributes['session'] = session
|
||||||
|
|
||||||
|
|
||||||
|
# # 创建表
|
||||||
|
# Base.metadata.create_all(engine)
|
||||||
|
#
|
||||||
|
# # 删除表
|
||||||
|
# Base.metadata.drop_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
# app = Flask(__name__)
|
||||||
|
# default_db_path = os.path.join(os.getcwd(), "meta_data")
|
||||||
|
# duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/dbgpt.db")
|
||||||
|
# app.config['SQLALCHEMY_DATABASE_URI'] = f'duckdb://{duckdb_path}'
|
||||||
|
# app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
|
||||||
|
# db = SQLAlchemy(app)
|
||||||
|
# migrate = Migrate(app, db)
|
||||||
|
#
|
||||||
|
# # 设置FLASK_APP环境变量
|
||||||
|
# import os
|
||||||
|
# os.environ['FLASK_APP'] = 'server.dbgpt_server.py'
|
||||||
|
#
|
||||||
|
# @app.cli.command("db_init")
|
||||||
|
# @with_appcontext
|
||||||
|
# def db_init():
|
||||||
|
# subprocess.run(["flask", "db", "init"])
|
||||||
|
#
|
||||||
|
# @app.cli.command("db_migrate")
|
||||||
|
# @with_appcontext
|
||||||
|
# def db_migrate():
|
||||||
|
# subprocess.run(["flask", "db", "migrate"])
|
||||||
|
#
|
||||||
|
# @app.cli.command("db_upgrade")
|
||||||
|
# @with_appcontext
|
||||||
|
# def db_upgrade():
|
||||||
|
# subprocess.run(["flask", "db", "upgrade"])
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def ddl_init_and_upgrade():
|
||||||
|
# Base.metadata.create_all(bind=engine)
|
||||||
|
# 生成并应用迁移脚本
|
||||||
|
# command.upgrade(alembic_cfg, 'head')
|
||||||
|
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
|
||||||
|
with engine.connect() as connection:
|
||||||
|
alembic_cfg.attributes['connection'] = connection
|
||||||
|
command.revision(alembic_cfg, "test", True)
|
||||||
|
command.upgrade(alembic_cfg, "head")
|
||||||
|
# alembic_cfg.attributes['connection'] = engine.connect()
|
||||||
|
# command.upgrade(alembic_cfg, 'head')
|
||||||
|
|
||||||
|
# with app.app_context():
|
||||||
|
# db_init()
|
||||||
|
# db_migrate()
|
||||||
|
# db_upgrade()
|
||||||
|
|
||||||
|
|
1
pilot/base_modules/meta_data/requirement.txt
Normal file
1
pilot/base_modules/meta_data/requirement.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
alembic==1.12.0
|
0
pilot/base_modules/module_factory.py
Normal file
0
pilot/base_modules/module_factory.py
Normal file
Loading…
Reference in New Issue
Block a user