feat(agent): Multi agent sdk (#976)

Co-authored-by: xtyuns <xtyuns@163.com>
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
Co-authored-by: qidanrui <qidanrui@gmail.com>
This commit is contained in:
明天
2023-12-27 16:25:55 +08:00
committed by GitHub
parent 69fb97e508
commit 9aec636b02
79 changed files with 6359 additions and 121 deletions

View File

View File

@@ -0,0 +1,208 @@
import json
import logging
import os
import glob
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from dbgpt.agent.common.schema import PluginStorageType
from dbgpt.agent.plugin.plugins_util import scan_plugins, update_from_git
logger = logging.getLogger(__name__)
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
class AgentHub:
def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_plugin_dao = MyPluginDao()
os.makedirs(plugin_dir, exist_ok=True)
self.plugin_dir = plugin_dir
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
branch_name = None
authorization = None
if plugin_entity.download_param:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user_name, plugin_name
)
if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = user_name
my_plugin_entity.user_name = user_name
my_plugin_entity.tenant = ""
else:
my_plugin_entity.user_code = Default_User
with self.hub_dao.session() as session:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else:
raise ValueError(
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
)
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.session() as session:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
if plugin_entity is not None:
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed > 0:
have_installed = True
break
if not have_installed:
plugin_repo_name = (
plugin_entity.storage_url.replace(".git", "")
.strip("/")
.split("/")[-1]
)
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
else:
files = glob.glob(
os.path.join(self.plugin_dir, f"{my_plugin_entity.file_name}")
)
for file in files:
os.remove(file)
def __download_from_git(self, github_repo, branch_name, authorization):
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
my_plugin_entity.type = hub_plugin.type
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path)
try:
for git_plugin in git_plugins:
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
if old_hub_info:
plugin_hub_info = old_hub_info
else:
plugin_hub_info = PluginHubEntity()
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {}
if branch_name:
download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0:
download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.raw_update(plugin_hub_info)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(self.plugin_dir, doc_file.filename),
)
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
if user is None or len(user) <= 0:
user = Default_User
for my_plugin in my_plugins:
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(
user, my_plugin._name
)
if my_plugin_entiy is None:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal"
my_plugin_entiy.user_code = user
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_plugin_dao.raw_update(my_plugin_entiy)
def reload_my_plugins(self):
logger.info(f"load_plugins start!")
return scan_plugins(self.plugin_dir)
def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}")
if not user:
user = Default_User
return self.my_plugin_dao.get_by_user(user)

View File

@@ -0,0 +1,160 @@
import logging
from fastapi import (
APIRouter,
Body,
UploadFile,
File,
)
from abc import ABC
from typing import List
from dbgpt.app.openapi.api_view_model import (
Result,
)
from dbgpt.serve.agent.model import (
PluginHubParam,
PagenationFilter,
PagenationResult,
PluginHubFilter,
)
from dbgpt.serve.agent.hub.agent_hub import AgentHub
from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity
from dbgpt.agent.plugin.plugins_util import scan_plugins
from dbgpt.agent.plugin.generator import PluginPromptGenerator
from dbgpt.configs.model_config import PLUGINS_DIR
from dbgpt.component import BaseComponent, ComponentType, SystemApp
router = APIRouter()
logger = logging.getLogger(__name__)
class ModuleAgent(BaseComponent, ABC):
name = ComponentType.AGENT_HUB
def __init__(self):
# load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(
self, generator: PluginPromptGenerator, select_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
if plugin._name in select_plugins:
if not plugin.can_handle_post_prompt():
continue
generator = plugin.post_prompt(generator)
return generator
module_agent = ModuleAgent()
@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
branch = (
update_param.branch
if update_param.branch is not None and len(update_param.branch) > 0
else "main"
)
authorization = (
update_param.authorization
if update_param.branch is not None and len(update_param.branch) > 0
else None
)
# TODO change it to async
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
filter_enetity: PluginHubEntity = PluginHubEntity()
if filter.filter:
attrs = vars(filter.filter) # 获取原始对象的属性字典
for attr, value in attrs.items():
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
datas, total_pages, total_count = agent_hub.hub_dao.list(
filter_enetity, filter.page_index, filter.page_size
)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
return Result.succ(agent_dicts)
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name: str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.failed(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name: str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.failed(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
await agent_hub.upload_my_plugin(doc_file, user)
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.failed(code="E0023", msg=f"Upload Personal Plugin Error {e}")