Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-03 09:34:04 +00:00)
Python · 264 lines · 10 KiB
import glob
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from fastapi import UploadFile
|
|
|
|
from dbgpt.agent.core.schema import PluginStorageType
|
|
from dbgpt.agent.resource.tool.autogpt.plugins_util import scan_plugins, update_from_git
|
|
from dbgpt.configs.model_config import PLUGINS_DIR
|
|
from dbgpt.serve.agent.hub.db.my_plugin_db import MyPluginDao, MyPluginEntity
|
|
from dbgpt.serve.agent.hub.db.plugin_hub_db import PluginHubDao, PluginHubEntity
|
|
|
|
# Module-level logger for plugin hub operations.
logger = logging.getLogger(__name__)

# User code recorded for plugin records when no user is supplied.
Default_User: str = "default"
# Default DB-GPT plugins git repository, used as a storage URL for hub entries.
DEFAULT_PLUGIN_REPO: str = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
# NOTE(review): not referenced by any visible code in this module.
TEMP_PLUGIN_PATH: str = ""
|
|
|
|
|
|
class PluginHub:
    """Plugin hub: the hub plugin catalogue plus each user's installed plugins.

    Hub entries (``PluginHubEntity``) are accessed through ``PluginHubDao``;
    per-user entries (``MyPluginEntity``) through ``MyPluginDao``. Plugin
    package files are kept in ``plugin_dir``.
    """

    def __init__(self, plugin_dir) -> None:
        """Create the DAOs and make sure ``plugin_dir`` exists.

        Args:
            plugin_dir: Directory where plugin packages are stored. The
                ``temp`` path inside it is the checkout target used by
                ``refresh_hub_from_git``.
        """
        self.hub_dao = PluginHubDao()
        self.my_plugin_dao = MyPluginDao()
        os.makedirs(plugin_dir, exist_ok=True)
        self.plugin_dir = plugin_dir
        self.temp_hub_file_path = os.path.join(plugin_dir, "temp")

    def install_plugin(self, plugin_name: str, user_name: str = None):
        """Install a hub plugin for ``user_name``.

        Downloads the plugin package from git into ``plugin_dir``, creates or
        updates the user's ``MyPluginEntity`` and, for a new install,
        increments the hub entry's ``installed`` counter.

        Args:
            plugin_name: Name of the hub plugin.
            user_name: Installing user; when falsy the record is stored under
                ``Default_User``.

        Raises:
            ValueError: If the plugin is not in the hub, its storage channel is
                not git, or downloading / persisting fails.
        """
        logger.info(f"install_plugin {plugin_name}")
        plugin_entity = self.hub_dao.get_by_name(plugin_name)
        if not plugin_entity:
            raise ValueError(f"Can't Find Plugin {plugin_name}!")
        if plugin_entity.storage_channel != PluginStorageType.Git.value:
            raise ValueError(
                f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
            )

        try:
            branch_name = None
            authorization = None
            if plugin_entity.download_param:
                download_param = json.loads(plugin_entity.download_param)
                branch_name = download_param.get("branch_name")
                authorization = download_param.get("authorization")
            file_name = self.__download_from_git(
                plugin_entity.storage_url, branch_name, authorization
            )

            # add to my plugins and edit hub status
            my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
                user_name, plugin_name
            )
            if my_plugin_entity is None:
                my_plugin_entity = self.__build_my_plugin(plugin_entity)
                # Count only new installs: uninstall_plugin decrements once, so
                # counting re-installs would keep the counter above zero and the
                # package file would never be cleaned up.
                plugin_entity.installed = (plugin_entity.installed or 0) + 1
            my_plugin_entity.file_name = file_name
            if user_name:
                # TODO use user
                my_plugin_entity.user_code = user_name
                my_plugin_entity.user_name = user_name
                my_plugin_entity.tenant = ""
            else:
                my_plugin_entity.user_code = Default_User

            with self.hub_dao.session() as session:
                if my_plugin_entity.id is None:
                    session.add(my_plugin_entity)
                else:
                    session.merge(my_plugin_entity)
                session.merge(plugin_entity)
        except Exception as e:
            # logger.exception records the message together with the traceback.
            logger.exception("install pluguin exception!")
            raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}") from e

    def uninstall_plugin(self, plugin_name, user):
        """Remove ``plugin_name`` from ``user``'s plugins.

        Deletes the matching ``MyPluginEntity`` rows (for every user when
        ``user`` is falsy), decrements the hub entry's ``installed`` counter
        and removes package files that are no longer needed.

        Args:
            plugin_name: Name of the plugin to remove.
            user: User code whose record is removed; falsy means all users.
        """
        logger.info(f"uninstall_plugin:{plugin_name},{user}")
        plugin_entity = self.hub_dao.get_by_name(plugin_name)
        my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
        if plugin_entity is not None:
            # Clamp at zero so a stray uninstall cannot make the counter negative.
            plugin_entity.installed = max((plugin_entity.installed or 0) - 1, 0)
        with self.hub_dao.session() as session:
            my_plugin_q = session.query(MyPluginEntity).filter(
                MyPluginEntity.name == plugin_name
            )
            if user:
                # Query.filter returns a new query instead of mutating in place;
                # without reassigning, the user restriction would be dropped and
                # every user's record for this plugin deleted.
                my_plugin_q = my_plugin_q.filter(MyPluginEntity.user_code == user)
            my_plugin_q.delete()
            if plugin_entity is not None:
                session.merge(plugin_entity)

        if plugin_entity is not None:
            # Delete the downloaded package only when no hub plugin sharing the
            # same storage URL is still installed.
            plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
            have_installed = any(
                (plugin_info.installed or 0) > 0 for plugin_info in plugin_infos
            )
            if not have_installed:
                plugin_repo_name = (
                    plugin_entity.storage_url.replace(".git", "")
                    .strip("/")
                    .split("/")[-1]
                )
                self._remove_plugin_files(f"{glob.escape(plugin_repo_name)}*")
        elif my_plugin_entity is not None and my_plugin_entity.file_name:
            # Not a hub plugin (presumably one added via upload_my_plugin):
            # remove the file recorded on the user's entry.
            self._remove_plugin_files(glob.escape(my_plugin_entity.file_name))

    def _remove_plugin_files(self, pattern: str) -> None:
        """Delete regular files in ``plugin_dir`` whose names match ``pattern``.

        ``pattern`` is a glob pattern relative to ``plugin_dir``. Matching
        directories are left in place rather than making ``os.remove`` raise.
        """
        search = os.path.join(glob.escape(self.plugin_dir), pattern)
        for file in glob.glob(search):
            if os.path.isfile(file):
                os.remove(file)

    def __download_from_git(self, github_repo, branch_name, authorization):
        """Fetch ``github_repo`` into ``plugin_dir`` via ``update_from_git``.

        Returns whatever ``update_from_git`` returns; ``install_plugin`` stores
        it as the plugin's ``file_name``.
        """
        return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)

    def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
        """Create a ``MyPluginEntity`` copying name, type and version of a hub entry."""
        my_plugin_entity = MyPluginEntity()
        my_plugin_entity.name = hub_plugin.name
        my_plugin_entity.type = hub_plugin.type
        my_plugin_entity.version = hub_plugin.version
        return my_plugin_entity

    def refresh_hub_from_git(
        self,
        github_repo: str = None,
        branch_name: str = "main",
        authorization: str = None,
    ):
        """Sync hub entries with the plugins found in a git repository.

        The repository is fetched into ``temp_hub_file_path`` and scanned; each
        plugin found is created or updated in the hub table by name.

        Args:
            github_repo: Repository URL. When falsy, ``DEFAULT_PLUGIN_REPO`` is
                recorded as the storage URL of newly created entries.
            branch_name: Branch to fetch; stored in ``download_param``.
            authorization: Optional credential; stored in ``download_param``.

        Raises:
            ValueError: If writing hub entries fails.
        """
        logger.info("refresh_hub_by_git start!")
        update_from_git(
            self.temp_hub_file_path, github_repo, branch_name, authorization
        )
        git_plugins = scan_plugins(self.temp_hub_file_path)
        # New hub entries must point at the repository they were scanned from;
        # otherwise install_plugin would download from a different repo.
        storage_url = github_repo or DEFAULT_PLUGIN_REPO
        try:
            for git_plugin in git_plugins:
                plugin_hub_info = self.hub_dao.get_by_name(git_plugin._name)
                if not plugin_hub_info:
                    plugin_hub_info = PluginHubEntity()
                    plugin_hub_info.type = ""
                    plugin_hub_info.storage_channel = PluginStorageType.Git.value
                    plugin_hub_info.storage_url = storage_url
                    plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
                    plugin_hub_info.email = getattr(git_plugin, "_email", "")
                    download_param = {}
                    if branch_name:
                        download_param["branch_name"] = branch_name
                    if authorization:
                        download_param["authorization"] = authorization
                    plugin_hub_info.download_param = json.dumps(download_param)
                    plugin_hub_info.installed = 0

                plugin_hub_info.name = git_plugin._name
                plugin_hub_info.version = git_plugin._version
                plugin_hub_info.description = git_plugin._description
                self.hub_dao.raw_update(plugin_hub_info)
        except Exception as e:
            raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}") from e

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize the filename to prevent directory traversal attacks.

        Args:
            filename: The original filename (may be ``None`` when the upload
                carried no name).

        Returns:
            str: The base name reduced to ``[a-zA-Z0-9._-]`` characters.

        Raises:
            ValueError: If the name is missing, becomes empty after cleaning,
                or starts with ``.``.
        """
        if not filename:
            raise ValueError("Invalid filename")

        # Only keep the basic filename, remove any path information
        filename = os.path.basename(filename)

        # Remove any unsafe characters
        filename = re.sub(r"[^a-zA-Z0-9._-]", "", filename)

        # Ensure the filename is not empty and valid
        if not filename or filename.startswith("."):
            raise ValueError("Invalid filename")

        return filename

    async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
        """Store an uploaded plugin package and register it for ``user``.

        The file is written under a sanitized name in ``plugin_dir`` and scanned
        with ``scan_plugins``. Every plugin found is saved as a ``Personal``
        ``MyPluginEntity``.

        Args:
            doc_file: The uploaded package.
            user: Owner; falsy values fall back to ``Default_User``.

        Raises:
            ValueError: If the filename is unsafe, or the package cannot be
                scanned as a plugin (the stored file is removed in that case).
        """
        # Verify and clean file names
        try:
            safe_filename = self._sanitize_filename(doc_file.filename)
        except ValueError as e:
            raise ValueError(f"Invalid plugin file: {str(e)}") from e

        # Structure a safe file path
        file_path = os.path.join(self.plugin_dir, safe_filename)

        # Verify the final path is within the allowed directory
        plugin_root = Path(self.plugin_dir).resolve()
        if not Path(file_path).resolve().is_relative_to(plugin_root):
            raise ValueError("Invalid file path")

        # Write to a temp file in the same directory, then atomically replace
        # the target. A failed upload never leaves a truncated package and
        # never destroys a previously stored one.
        tmp_fd, tmp_path = tempfile.mkstemp(dir=self.plugin_dir)
        try:
            with os.fdopen(tmp_fd, "wb") as tmp:
                tmp.write(await doc_file.read())
            os.replace(tmp_path, file_path)
        except Exception:
            # Ensure the temporary file is cleaned up
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        # Scan and validate the plugin
        try:
            my_plugins = scan_plugins(self.plugin_dir, safe_filename)
        except Exception as e:
            # If the plugin validation fails, clean up the uploaded file
            if os.path.exists(file_path):
                os.remove(file_path)
            raise ValueError(f"Invalid plugin file: {str(e)}") from e

        if not user:
            user = Default_User

        # Update the database
        for my_plugin in my_plugins:
            my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
                user, my_plugin._name
            )
            if my_plugin_entity is None:
                my_plugin_entity = MyPluginEntity()
            my_plugin_entity.name = my_plugin._name
            my_plugin_entity.version = my_plugin._version
            my_plugin_entity.type = "Personal"
            my_plugin_entity.user_code = user
            my_plugin_entity.user_name = user
            my_plugin_entity.tenant = ""
            my_plugin_entity.file_name = safe_filename
            self.my_plugin_dao.raw_update(my_plugin_entity)

    def reload_my_plugins(self):
        """Rescan ``plugin_dir`` and return what ``scan_plugins`` finds there."""
        logger.info("load_plugins start!")
        return scan_plugins(self.plugin_dir)

    def get_my_plugin(self, user: str):
        """Return the plugin records of ``user`` (``Default_User`` when falsy)."""
        logger.info(f"get_my_plugin:{user}")
        if not user:
            user = Default_User
        return self.my_plugin_dao.get_by_user(user)
|
|
|
|
|
|
# Module-wide hub instance; constructing it creates PLUGINS_DIR if missing.
plugin_hub = PluginHub(plugin_dir=PLUGINS_DIR)
|