mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 13:58:58 +00:00
fix(ChatExcel): ChatExcel OutParse Bug Fix
1.ChatDashboard Display optimization
This commit is contained in:
@@ -1,13 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class AgentFacade(ABC):
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
|
||||
|
||||
|
@@ -27,7 +27,6 @@ def execute_ai_response_json(
|
||||
user_input: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
||||
Args:
|
||||
command_registry:
|
||||
ai_response:
|
||||
@@ -65,6 +64,8 @@ def execute_ai_response_json(
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments,
|
||||
|
@@ -3,6 +3,8 @@ import time
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
UploadFile,
|
||||
File,
|
||||
)
|
||||
|
||||
from typing import List
|
||||
@@ -13,14 +15,77 @@ from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
)
|
||||
|
||||
|
||||
from .model import PluginHubParam, PagenationFilter, PagenationResult
|
||||
from .hub.agent_hub import AgentHub
|
||||
from .db.plugin_hub_db import PluginHubEntity
|
||||
from .db.my_plugin_db import MyPluginEntity
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
|
||||
router = APIRouter()
|
||||
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
|
||||
|
||||
|
||||
@router.get("/v1/mange/agent/list", response_model=Result[str])
|
||||
async def get_agent_list():
|
||||
logger.info(f"get_agent_list!")
|
||||
|
||||
@router.post("/v1/agent/hub/update", response_model=Result[str])
|
||||
async def agent_hub_update(update_param: PluginHubParam = Body()):
|
||||
logger.info(f"agent_hub_update:{update_param.__dict__}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
|
||||
return Result.succ(None)
|
||||
|
||||
|
||||
@router.post("/v1/agent/query", response_model=Result[str])
|
||||
async def get_agent_list(filter: PagenationFilter[PluginHubEntity] = Body()):
|
||||
logger.info(f"get_agent_list:{json.dumps(filter)}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
datas, total_pages, total_count = agent_hub.hub_dao.list(filter.filter, filter.page_index, filter.page_size)
|
||||
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
|
||||
result.page_index = filter.page_index
|
||||
result.page_size = filter.page_size
|
||||
result.total_page = total_pages
|
||||
result.total_row_count = total_count
|
||||
result.datas = datas
|
||||
return Result.succ(result)
|
||||
|
||||
@router.post("/v1/agent/my", response_model=Result[str])
|
||||
async def my_agents(user:str= None):
|
||||
logger.info(f"my_agents:{json.dumps(my_agents)}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
return Result.succ(agent_hub.get_my_plugin(user))
|
||||
|
||||
|
||||
@router.post("/v1/agent/install", response_model=Result[str])
|
||||
async def agent_install(plugin_name:str, user: str = None):
|
||||
logger.info(f"agent_install:{plugin_name},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.install_plugin(plugin_name, user)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Plugin Install Error!", e)
|
||||
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
|
||||
|
||||
|
||||
|
||||
@router.post("/v1/agent/uninstall", response_model=Result[str])
|
||||
async def agent_uninstall(plugin_name:str, user: str = None):
|
||||
logger.info(f"agent_uninstall:{plugin_name},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.uninstall_plugin(plugin_name, user)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Plugin Uninstall Error!", e)
|
||||
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
|
||||
|
||||
|
||||
@router.post("/v1/personal/agent/upload", response_model=Result[str])
|
||||
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
|
||||
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.upload_my_plugin(doc_file, user)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Upload Personal Plugin Error!", e)
|
||||
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")
|
||||
|
||||
|
@@ -18,6 +18,7 @@ class MyPluginEntity(Base):
|
||||
user_code = Column(String, nullable=True, comment="user code")
|
||||
user_name = Column(String, nullable=True, comment="user name")
|
||||
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
||||
file_name = Column(String, nullable=False, comment="plugin package file name")
|
||||
type = Column(String, comment="plugin type")
|
||||
version = Column(String, comment="plugin version")
|
||||
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
|
||||
@@ -74,6 +75,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
|
||||
session = self.Session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
all_count = my_plugins.count()
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
@@ -101,7 +103,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
|
||||
result = my_plugins.all()
|
||||
session.close()
|
||||
return result
|
||||
total_pages = all_count // page_size
|
||||
if all_count % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
|
||||
return result, total_pages, all_count
|
||||
|
||||
|
||||
def count(self, query: MyPluginEntity):
|
||||
session = self.Session()
|
||||
|
@@ -13,12 +13,14 @@ class PluginHubEntity(Base):
|
||||
__tablename__ = 'plugin_hub'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
||||
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
||||
description = Column(String, nullable=False, comment="plugin description")
|
||||
author = Column(String, nullable=True, comment="plugin author")
|
||||
email = Column(String, nullable=True, comment="plugin author email")
|
||||
type = Column(String, comment="plugin type")
|
||||
version = Column(String, comment="plugin version")
|
||||
storage_channel = Column(String, comment="plugin storage channel")
|
||||
storage_url = Column(String, comment="plugin download url")
|
||||
download_param = Column(String, comment="plugin download param")
|
||||
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
|
||||
installed = Column(Boolean, default=False, comment="plugin already installed")
|
||||
|
||||
@@ -61,6 +63,8 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
|
||||
session = self.Session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
all_count = plugin_hubs.count()
|
||||
|
||||
if query.id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
@@ -84,6 +88,20 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
|
||||
result = plugin_hubs.all()
|
||||
session.close()
|
||||
|
||||
total_pages = all_count // page_size
|
||||
if all_count % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
|
||||
return result, total_pages, all_count
|
||||
|
||||
def get_by_storage_url(self, storage_url):
|
||||
session = self.Session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
|
||||
result = plugin_hubs.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def get_by_name(self, name: str) -> PluginHubEntity:
|
||||
|
@@ -1,39 +1,54 @@
|
||||
import json
|
||||
import logging
|
||||
import git
|
||||
import os
|
||||
import glob
|
||||
import zipfile
|
||||
import requests
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
from typing import Any
|
||||
import tempfile
|
||||
|
||||
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||
from .schema import PluginStorageType
|
||||
from ..plugins_util import scan_plugins, update_from_git
|
||||
|
||||
logger = logging.getLogger("agent_hub")
|
||||
Default_User = "default"
|
||||
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
|
||||
TEMP_PLUGIN_PATH = ""
|
||||
|
||||
|
||||
class AgentHub:
|
||||
def __init__(self, temp_hub_file_path:str = "") -> None:
|
||||
def __init__(self, plugin_dir) -> None:
|
||||
self.hub_dao = PluginHubDao()
|
||||
self.my_lugin_dao = MyPluginDao()
|
||||
if temp_hub_file_path:
|
||||
self.temp_hub_file_path = temp_hub_file_path
|
||||
else:
|
||||
self.temp_hub_file_path = os.path.join(os.getcwd(), "plugins", "temp")
|
||||
os.makedirs(plugin_dir, exist_ok=True)
|
||||
self.plugin_dir = plugin_dir
|
||||
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
|
||||
|
||||
def install_plugin(self, plugin_name: str, user_name: str = None):
|
||||
logger.info(f"install_plugin {plugin_name}")
|
||||
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
if plugin_entity:
|
||||
if plugin_entity.storage_channel == PluginStorageType.Git.value:
|
||||
try:
|
||||
self.__download_from_git(plugin_name, plugin_entity.storage_url)
|
||||
self.load_plugin(plugin_name)
|
||||
branch_name = None
|
||||
authorization = None
|
||||
if plugin_entity.download_param:
|
||||
download_param = json.loads(plugin_entity.download_param)
|
||||
branch_name = download_param.get("branch_name")
|
||||
authorization = download_param.get("authorization")
|
||||
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
|
||||
|
||||
# add to my plugins and edit hub status
|
||||
plugin_entity.installed = True
|
||||
|
||||
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||
my_plugin_entity.file_name = file_name
|
||||
if user_name:
|
||||
# TODO use user
|
||||
my_plugin_entity.user_code = ""
|
||||
@@ -55,9 +70,43 @@ class AgentHub:
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
||||
|
||||
def uninstall_plugin(self, plugin_name, user):
|
||||
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||
plugin_entity.installed = False
|
||||
with self.hub_dao.Session() as session:
|
||||
try:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
|
||||
# delete package file if not use
|
||||
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
|
||||
have_installed = False
|
||||
for plugin_info in plugin_infos:
|
||||
if plugin_info.installed:
|
||||
have_installed = True
|
||||
break
|
||||
if not have_installed:
|
||||
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
|
||||
files = glob.glob(
|
||||
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
|
||||
def __download_from_git(self, github_repo, branch_name, authorization):
|
||||
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
|
||||
|
||||
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
|
||||
my_plugin_entity = MyPluginEntity()
|
||||
my_plugin_entity.name = hub_plugin.name
|
||||
@@ -65,29 +114,68 @@ class AgentHub:
|
||||
my_plugin_entity.version = hub_plugin.version
|
||||
return my_plugin_entity
|
||||
|
||||
def __fetch_from_git(self):
|
||||
logger.info("fetch plugins from git to local path:{}", self.temp_hub_file_path)
|
||||
os.makedirs(self.temp_hub_file_path, exist_ok=True)
|
||||
repo = git.Repo(self.temp_hub_file_path)
|
||||
if repo.is_repo():
|
||||
repo.remotes.origin.pull()
|
||||
else:
|
||||
git.Repo.clone_from(DEFAULT_PLUGIN_REPO, self.temp_hub_file_path)
|
||||
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
|
||||
logger.info("refresh_hub_by_git start!")
|
||||
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
|
||||
git_plugins = scan_plugins(self.temp_hub_file_path)
|
||||
for git_plugin in git_plugins:
|
||||
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
|
||||
if old_hub_info:
|
||||
plugin_hub_info = old_hub_info
|
||||
else:
|
||||
plugin_hub_info = PluginHubEntity()
|
||||
plugin_hub_info.type = ""
|
||||
plugin_hub_info.storage_channel = PluginStorageType.Git.value
|
||||
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
|
||||
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
|
||||
plugin_hub_info.email = getattr(git_plugin, '_email', '')
|
||||
plugin_hub_info.download_param = json.dumps({
|
||||
branch_name: branch_name,
|
||||
authorization: authorization
|
||||
})
|
||||
plugin_hub_info.installed = False
|
||||
|
||||
# if repo.head.is_valid():
|
||||
# clone succ, fetch plugins info
|
||||
plugin_hub_info.name = git_plugin._name
|
||||
plugin_hub_info.version = git_plugin._version
|
||||
plugin_hub_info.description = git_plugin._description
|
||||
self.hub_dao.update(plugin_hub_info)
|
||||
|
||||
def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
|
||||
|
||||
# We can not move temp file in windows system when we open file in context of `with`
|
||||
file_path = os.path.join(self.plugin_dir, doc_file.filename)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(
|
||||
dir=os.path.join(self.plugin_dir)
|
||||
)
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
shutil.move(
|
||||
tmp_path,
|
||||
os.path.join(self.plugin_dir, doc_file.filename),
|
||||
)
|
||||
|
||||
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
|
||||
for my_plugin in my_plugins:
|
||||
my_plugin_entiy = MyPluginEntity()
|
||||
|
||||
my_plugin_entiy.name = my_plugin._name
|
||||
my_plugin_entiy.version = my_plugin._version
|
||||
my_plugin_entiy.type = "Personal"
|
||||
my_plugin_entiy.user_code = user
|
||||
my_plugin_entiy.user_name = user
|
||||
my_plugin_entiy.tenant = ""
|
||||
my_plugin_entiy.file_name = doc_file.filename
|
||||
|
||||
self.my_lugin_dao.update(my_plugin_entiy)
|
||||
|
||||
|
||||
def upload_plugin_in_hub(self, name: str, path: str):
|
||||
|
||||
pass
|
||||
|
||||
def __download_from_git(self, plugin_name, url):
|
||||
pass
|
||||
|
||||
def load_plugin(self, plugin_name):
|
||||
logger.info(f"load_plugin:{plugin_name}")
|
||||
pass
|
||||
def reload_my_plugins(self):
|
||||
logger.info(f"load_plugins start!")
|
||||
return scan_plugins(self.plugin_dir)
|
||||
|
||||
def get_my_plugin(self, user: str):
|
||||
logger.info(f"get_my_plugin:{user}")
|
||||
@@ -95,7 +183,3 @@ class AgentHub:
|
||||
user = Default_User
|
||||
return self.my_lugin_dao.get_by_user(user)
|
||||
|
||||
def uninstall_plugin(self, plugin_name, user):
|
||||
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
||||
|
||||
pass
|
||||
|
30
pilot/base_modules/agent/model.py
Normal file
30
pilot/base_modules/agent/model.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import TypedDict, Optional, Dict, List
|
||||
from dataclasses import dataclass
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Generic, Any
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class PagenationFilter(Generic[T]):
|
||||
page_index: int = 1
|
||||
page_size: int = 20
|
||||
filter: T = None
|
||||
|
||||
class PagenationResult(Generic[T]):
|
||||
page_index: int = 1
|
||||
page_size: int = 20
|
||||
total_page: int = 0
|
||||
total_row_count: int = 0
|
||||
datas: List[T] = []
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginHubParam:
|
||||
channel: str = Field(..., description="Plugin storage channel")
|
||||
url: str = Field(..., description="Plugin storage url")
|
||||
|
||||
branch: str = Field(..., description="When the storage channel is github, use to specify the branch", nullable=True)
|
||||
authorization: str = Field(..., description="github download authorization", nullable=True)
|
||||
|
||||
|
0
pilot/base_modules/agent/module.py
Normal file
0
pilot/base_modules/agent/module.py
Normal file
2
pilot/base_modules/agent/plugins_loader.py
Normal file
2
pilot/base_modules/agent/plugins_loader.py
Normal file
@@ -0,0 +1,2 @@
|
||||
|
||||
class PluginLoader():
|
@@ -5,6 +5,7 @@ import os
|
||||
import glob
|
||||
import zipfile
|
||||
import requests
|
||||
import git
|
||||
import threading
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
@@ -20,7 +21,6 @@ from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
"""
|
||||
Loader zip plugin file. Native support Auto_gpt_plugin
|
||||
@@ -106,7 +106,7 @@ def load_native_plugins(cfg: Config):
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(response.content)
|
||||
print("save file")
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
cfg.set_plugins(scan_plugins(cfg.debug_mode))
|
||||
else:
|
||||
print("get file faild,response code:", response.status_code)
|
||||
except Exception as e:
|
||||
@@ -116,7 +116,30 @@ def load_native_plugins(cfg: Config):
|
||||
t.start()
|
||||
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
|
||||
logger.info(f"__scan_plugin_file:{file_path},{debug}")
|
||||
loaded_plugins = []
|
||||
if moduleList := inspect_zip_for_modules(str(file_path), debug):
|
||||
for module in moduleList:
|
||||
plugin = Path(plugin)
|
||||
module = Path(module)
|
||||
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||
zipped_package = zipimporter(str(plugin))
|
||||
zipped_module = zipped_package.load_module(str(module.parent))
|
||||
for key in dir(zipped_module):
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
a_module = getattr(zipped_module, key)
|
||||
a_keys = dir(a_module)
|
||||
if (
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
):
|
||||
loaded_plugins.append(a_module())
|
||||
return loaded_plugins
|
||||
|
||||
def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
@@ -126,31 +149,16 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
|
||||
Returns:
|
||||
List[Tuple[str, Path]]: List of plugins.
|
||||
"""
|
||||
loaded_plugins = []
|
||||
current_dir = os.getcwd()
|
||||
print(current_dir)
|
||||
# Generic plugins
|
||||
plugins_path_path = Path(PLUGINS_DIR)
|
||||
|
||||
for plugin in plugins_path_path.glob("*.zip"):
|
||||
if moduleList := inspect_zip_for_modules(str(plugin), debug):
|
||||
for module in moduleList:
|
||||
plugin = Path(plugin)
|
||||
module = Path(module)
|
||||
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||
zipped_package = zipimporter(str(plugin))
|
||||
zipped_module = zipped_package.load_module(str(module.parent))
|
||||
for key in dir(zipped_module):
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
a_module = getattr(zipped_module, key)
|
||||
a_keys = dir(a_module)
|
||||
if (
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
):
|
||||
loaded_plugins.append(a_module())
|
||||
loaded_plugins = []
|
||||
# Generic plugins
|
||||
plugins_path = Path(plugins_file_path)
|
||||
if file_name:
|
||||
plugin_path = Path(plugins_path, file_name)
|
||||
loaded_plugins = __scan_plugin_file(plugin_path)
|
||||
else:
|
||||
for plugin_path in plugins_path.glob("*.zip"):
|
||||
loaded_plugins.extend(__scan_plugin_file(plugin_path))
|
||||
|
||||
if loaded_plugins:
|
||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||
@@ -183,5 +191,55 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||
return ack.lower() == cfg.authorise_key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(inspect_zip_for_modules("/Users/tuyang.yhj/Downloads/DB-GPT-Plugins-main (1).zip"))
|
||||
def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
|
||||
authorization: str = "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"):
|
||||
os.makedirs(download_path, exist_ok=True)
|
||||
if github_repo:
|
||||
if github_repo.index("github.com") <= 0:
|
||||
raise ValueError("Not a correct Github repository address!" + github_repo)
|
||||
github_repo = github_repo.replace(".git", "")
|
||||
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
|
||||
plugin_repo_name = github_repo.strip('/').split('/')[-1]
|
||||
else:
|
||||
url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
|
||||
plugin_repo_name = "DB-GPT-Plugins"
|
||||
try:
|
||||
session = requests.Session()
|
||||
response = session.get(
|
||||
url,
|
||||
headers={"Authorization": authorization},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(download_path)
|
||||
files = glob.glob(
|
||||
os.path.join(plugins_path_path, f"{plugin_repo_name}*")
|
||||
)
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||
file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
|
||||
print(file_name)
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(response.content)
|
||||
return plugin_repo_name
|
||||
else:
|
||||
logger.error("update plugins faild,response code:", response.status_code)
|
||||
raise ValueError("download plugin faild!" + response.status_code)
|
||||
except Exception as e:
|
||||
logger.error("update plugins from git exception!" + str(e))
|
||||
raise ValueError("download plugin exception!", e)
|
||||
|
||||
|
||||
def __fetch_from_git(local_path, git_url):
|
||||
logger.info("fetch plugins from git to local path:{}", local_path)
|
||||
os.makedirs(local_path, exist_ok=True)
|
||||
repo = git.Repo(local_path)
|
||||
if repo.is_repo():
|
||||
repo.remotes.origin.pull()
|
||||
else:
|
||||
git.Repo.clone_from(git_url, local_path)
|
||||
|
||||
# if repo.head.is_valid():
|
||||
# clone succ, fetch plugins info
|
@@ -56,6 +56,13 @@ class ChatScene(Enum):
|
||||
param_types=["Plugin Select"],
|
||||
)
|
||||
|
||||
ChatAgent = Scene(
|
||||
code="chat_agent",
|
||||
name="Agent Chat",
|
||||
describe="Use tools through dialogue to accomplish your goals.",
|
||||
param_types=["Plugin Select"],
|
||||
)
|
||||
|
||||
InnerChatDBSummary = Scene(
|
||||
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
|
||||
)
|
||||
|
0
pilot/scene/chat_agent/__init__.py
Normal file
0
pilot/scene/chat_agent/__init__.py
Normal file
72
pilot/scene/chat_agent/chat.py
Normal file
72
pilot/scene/chat_agent/chat.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
from pilot.base_modules.agent.commands.command import execute_command
|
||||
from pilot.base_modules.agent import PluginPromptGenerator
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class ChatWithPlugin(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatAgent.value()
|
||||
plugins_prompt_generator: PluginPromptGenerator
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(self, chat_param: Dict):
|
||||
self.plugin_selector = chat_param.select_param
|
||||
chat_param["chat_mode"] = ChatScene.ChatAgent
|
||||
super().__init__(chat_param=chat_param)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||
# 加载插件中可用命令
|
||||
self.select_plugin = self.plugin_selector
|
||||
if self.select_plugin:
|
||||
for plugin in CFG.plugins:
|
||||
if plugin._name == self.plugin_selector:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
self.plugins_prompt_generator = plugin.post_prompt(
|
||||
self.plugins_prompt_generator
|
||||
)
|
||||
|
||||
else:
|
||||
for plugin in CFG.plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
self.plugins_prompt_generator = plugin.post_prompt(
|
||||
self.plugins_prompt_generator
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
input_values = {
|
||||
"input": self.current_user_input,
|
||||
"constraints": self.__list_to_prompt_str(
|
||||
list(self.plugins_prompt_generator.constraints)
|
||||
),
|
||||
"commands_infos": self.plugins_prompt_generator.generate_commands_string(),
|
||||
}
|
||||
return input_values
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
print(f"do_action:{prompt_response}")
|
||||
## plugin command run
|
||||
return execute_command(
|
||||
str(prompt_response.command.get("name")),
|
||||
prompt_response.command.get("args", {}),
|
||||
self.plugins_prompt_generator,
|
||||
)
|
||||
|
||||
def chat_show(self):
|
||||
super().chat_show()
|
||||
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||
|
||||
def generate(self, p) -> str:
|
||||
return super().generate(p)
|
||||
|
||||
@property
|
||||
def chat_type(self) -> str:
|
||||
return ChatScene.ChatAgent.value
|
23
pilot/scene/chat_agent/example.py
Normal file
23
pilot/scene/chat_agent/example.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pilot.prompts.example_base import ExampleSelector
|
||||
|
||||
## Two examples are defined by default
|
||||
EXAMPLES = [
|
||||
{
|
||||
"messages": [
|
||||
{"type": "human", "data": {"content": "查询xxx", "example": True}},
|
||||
{
|
||||
"type": "ai",
|
||||
"data": {
|
||||
"content": """{
|
||||
\"thoughts\": \"thought text\",
|
||||
\"speak\": \"thoughts summary to say to user\",
|
||||
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
|
||||
}""",
|
||||
"example": True,
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
|
46
pilot/scene/chat_agent/out_parser.py
Normal file
46
pilot/scene/chat_agent/out_parser.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import json
|
||||
from typing import Dict, NamedTuple
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
|
||||
class PluginAction(NamedTuple):
|
||||
command: Dict
|
||||
speak: str = ""
|
||||
thoughts: str = ""
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
clean_json_str = super().parse_prompt_response(model_out_text)
|
||||
print(clean_json_str)
|
||||
if not clean_json_str:
|
||||
raise ValueError("model server response not have json!")
|
||||
try:
|
||||
response = json.loads(clean_json_str)
|
||||
except Exception as e:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
|
||||
speak = ""
|
||||
thoughts = ""
|
||||
for key in sorted(response):
|
||||
if key.strip() == "command":
|
||||
command = response[key]
|
||||
if key.strip() == "thoughts":
|
||||
thoughts = response[key]
|
||||
if key.strip() == "speak":
|
||||
speak = response[key]
|
||||
return PluginAction(command, speak, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
### tool out data to table view
|
||||
print(f"parse_view_response:{speak},{str(data)}")
|
||||
view_text = f"##### {speak}" + "\n" + str(data)
|
||||
return view_text
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
78
pilot/scene/chat_agent/prompt.py
Normal file
78
pilot/scene/chat_agent/prompt.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.schema import SeparatorStyle, ExampleType
|
||||
|
||||
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
|
||||
from pilot.scene.chat_execution.example import plugin_example
|
||||
|
||||
CFG = Config()
|
||||
|
||||
_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant."
|
||||
|
||||
_DEFAULT_TEMPLATE_EN = """
|
||||
You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals.
|
||||
Tool list:
|
||||
{tool_list}
|
||||
Constraint:
|
||||
1. After selecting an available tool, please ensure that the output results include the following parts to use the tool:
|
||||
<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
|
||||
2. If you cannot analyze the exact tool for the problem, you can consider using the search engine tool among the tools first.
|
||||
3. Parameter content may need to be inferred based on the user's goals, not just extracted from text
|
||||
4. If you cannot find a suitable tool, please answer Unable to complete the goal.
|
||||
{expand_constraints}
|
||||
User goals:
|
||||
{user_goal}
|
||||
"""
|
||||
|
||||
_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手!"
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
请一步步思考,如何在满足下面约束条件的前提下,回答或解决用户问题或目标。
|
||||
工具列表:
|
||||
{tool_list}
|
||||
约束条件:
|
||||
1. 找到可用的工具后,请确保输出结果包含以下内容用来使用工具:<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
|
||||
2.任务重可以使用多个工具,上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
|
||||
3.如果有多个工具被使用,后续工具需要第一个工具的结果作为参数的, 使用如下文本来替代参数值:<api1-result>
|
||||
4.如果对于问题无法理解和解决,可以考虑优先使用工具中的搜索引擎工具
|
||||
5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
|
||||
6.如果中无法找到合适的工具,请回答无法完成目标。
|
||||
7.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中
|
||||
{expand_constraints}
|
||||
用户目标:
|
||||
{user_goal}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
_PROMPT_SCENE_DEFINE=(
|
||||
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||
)
|
||||
|
||||
RESPONSE_FORMAT = None
|
||||
|
||||
|
||||
EXAMPLE_TYPE = ExampleType.ONE_SHOT
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
### Whether the model service is streaming output
|
||||
PROMPT_NEED_STREAM_OUT = True
|
||||
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatAgent.value(),
|
||||
input_variables=["tool_list", "expand_constraints", "user_goal"],
|
||||
response_format=None,
|
||||
template_define=_PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_STREAM_OUT,
|
||||
output_parser=PluginChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
|
||||
),
|
||||
# example_selector=plugin_example,
|
||||
)
|
||||
|
||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
@@ -1 +0,0 @@
|
||||
|
@@ -6,6 +6,7 @@ from typing import Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.componet import SystemApp
|
||||
from pilot.utils.parameter_utils import BaseParameters
|
||||
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
|
||||
@@ -30,7 +31,7 @@ def async_db_summery(system_app: SystemApp):
|
||||
def server_init(args, system_app: SystemApp):
|
||||
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
|
||||
|
||||
from pilot.base_modules.agent.plugins import scan_plugins
|
||||
from pilot.base_modules.agent.plugins_util import scan_plugins
|
||||
|
||||
# logger.info(f"args: {args}")
|
||||
|
||||
@@ -43,7 +44,7 @@ def server_init(args, system_app: SystemApp):
|
||||
# load_native_plugins(cfg)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
||||
cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
|
||||
|
||||
# Loader plugins and commands
|
||||
command_categories = [
|
||||
@@ -126,3 +127,4 @@ class WebWerverParameters(BaseParameters):
|
||||
},
|
||||
)
|
||||
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
|
||||
|
||||
|
Reference in New Issue
Block a user