mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 13:58:58 +00:00
fix(ChatExcel): ChatExcel OutParse Bug Fix
1.ChatDashboard Display optimization
This commit is contained in:
@@ -1,13 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AgentFacade(ABC):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.model = None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -27,7 +27,6 @@ def execute_ai_response_json(
|
|||||||
user_input: str = None,
|
user_input: str = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
command_registry:
|
command_registry:
|
||||||
ai_response:
|
ai_response:
|
||||||
@@ -65,6 +64,8 @@ def execute_ai_response_json(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def execute_command(
|
def execute_command(
|
||||||
command_name: str,
|
command_name: str,
|
||||||
arguments,
|
arguments,
|
||||||
|
@@ -3,6 +3,8 @@ import time
|
|||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Body,
|
Body,
|
||||||
|
UploadFile,
|
||||||
|
File,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -13,14 +15,77 @@ from pilot.openapi.api_view_model import (
|
|||||||
Result,
|
Result,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .model import PluginHubParam, PagenationFilter, PagenationResult
|
||||||
|
from .hub.agent_hub import AgentHub
|
||||||
|
from .db.plugin_hub_db import PluginHubEntity
|
||||||
|
from .db.my_plugin_db import MyPluginEntity
|
||||||
|
from pilot.configs.model_config import PLUGINS_DIR
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
|
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/mange/agent/list", response_model=Result[str])
|
@router.post("/v1/agent/hub/update", response_model=Result[str])
|
||||||
async def get_agent_list():
|
async def agent_hub_update(update_param: PluginHubParam = Body()):
|
||||||
logger.info(f"get_agent_list!")
|
logger.info(f"agent_hub_update:{update_param.__dict__}")
|
||||||
|
agent_hub = AgentHub(PLUGINS_DIR)
|
||||||
|
agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/agent/query", response_model=Result[str])
|
||||||
|
async def get_agent_list(filter: PagenationFilter[PluginHubEntity] = Body()):
|
||||||
|
logger.info(f"get_agent_list:{json.dumps(filter)}")
|
||||||
|
agent_hub = AgentHub(PLUGINS_DIR)
|
||||||
|
datas, total_pages, total_count = agent_hub.hub_dao.list(filter.filter, filter.page_index, filter.page_size)
|
||||||
|
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
|
||||||
|
result.page_index = filter.page_index
|
||||||
|
result.page_size = filter.page_size
|
||||||
|
result.total_page = total_pages
|
||||||
|
result.total_row_count = total_count
|
||||||
|
result.datas = datas
|
||||||
|
return Result.succ(result)
|
||||||
|
|
||||||
|
@router.post("/v1/agent/my", response_model=Result[str])
|
||||||
|
async def my_agents(user:str= None):
|
||||||
|
logger.info(f"my_agents:{json.dumps(my_agents)}")
|
||||||
|
agent_hub = AgentHub(PLUGINS_DIR)
|
||||||
|
return Result.succ(agent_hub.get_my_plugin(user))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/agent/install", response_model=Result[str])
|
||||||
|
async def agent_install(plugin_name:str, user: str = None):
|
||||||
|
logger.info(f"agent_install:{plugin_name},{user}")
|
||||||
|
try:
|
||||||
|
agent_hub = AgentHub(PLUGINS_DIR)
|
||||||
|
agent_hub.install_plugin(plugin_name, user)
|
||||||
|
return Result.succ(None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Plugin Install Error!", e)
|
||||||
|
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/agent/uninstall", response_model=Result[str])
|
||||||
|
async def agent_uninstall(plugin_name:str, user: str = None):
|
||||||
|
logger.info(f"agent_uninstall:{plugin_name},{user}")
|
||||||
|
try:
|
||||||
|
agent_hub = AgentHub(PLUGINS_DIR)
|
||||||
|
agent_hub.uninstall_plugin(plugin_name, user)
|
||||||
|
return Result.succ(None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Plugin Uninstall Error!", e)
|
||||||
|
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/personal/agent/upload", response_model=Result[str])
|
||||||
|
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
|
||||||
|
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
|
||||||
|
try:
|
||||||
|
agent_hub = AgentHub(PLUGINS_DIR)
|
||||||
|
agent_hub.upload_my_plugin(doc_file, user)
|
||||||
|
return Result.succ(None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Upload Personal Plugin Error!", e)
|
||||||
|
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")
|
||||||
|
|
||||||
|
@@ -18,6 +18,7 @@ class MyPluginEntity(Base):
|
|||||||
user_code = Column(String, nullable=True, comment="user code")
|
user_code = Column(String, nullable=True, comment="user code")
|
||||||
user_name = Column(String, nullable=True, comment="user name")
|
user_name = Column(String, nullable=True, comment="user name")
|
||||||
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
||||||
|
file_name = Column(String, nullable=False, comment="plugin package file name")
|
||||||
type = Column(String, comment="plugin type")
|
type = Column(String, comment="plugin type")
|
||||||
version = Column(String, comment="plugin version")
|
version = Column(String, comment="plugin version")
|
||||||
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
|
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
|
||||||
@@ -74,6 +75,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
|||||||
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
|
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
my_plugins = session.query(MyPluginEntity)
|
my_plugins = session.query(MyPluginEntity)
|
||||||
|
all_count = my_plugins.count()
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||||
if query.name is not None:
|
if query.name is not None:
|
||||||
@@ -101,7 +103,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
|||||||
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
|
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
|
||||||
result = my_plugins.all()
|
result = my_plugins.all()
|
||||||
session.close()
|
session.close()
|
||||||
return result
|
total_pages = all_count // page_size
|
||||||
|
if all_count % page_size != 0:
|
||||||
|
total_pages += 1
|
||||||
|
|
||||||
|
|
||||||
|
return result, total_pages, all_count
|
||||||
|
|
||||||
|
|
||||||
def count(self, query: MyPluginEntity):
|
def count(self, query: MyPluginEntity):
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
|
@@ -13,12 +13,14 @@ class PluginHubEntity(Base):
|
|||||||
__tablename__ = 'plugin_hub'
|
__tablename__ = 'plugin_hub'
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
||||||
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
name = Column(String, unique=True, nullable=False, comment="plugin name")
|
||||||
|
description = Column(String, nullable=False, comment="plugin description")
|
||||||
author = Column(String, nullable=True, comment="plugin author")
|
author = Column(String, nullable=True, comment="plugin author")
|
||||||
email = Column(String, nullable=True, comment="plugin author email")
|
email = Column(String, nullable=True, comment="plugin author email")
|
||||||
type = Column(String, comment="plugin type")
|
type = Column(String, comment="plugin type")
|
||||||
version = Column(String, comment="plugin version")
|
version = Column(String, comment="plugin version")
|
||||||
storage_channel = Column(String, comment="plugin storage channel")
|
storage_channel = Column(String, comment="plugin storage channel")
|
||||||
storage_url = Column(String, comment="plugin download url")
|
storage_url = Column(String, comment="plugin download url")
|
||||||
|
download_param = Column(String, comment="plugin download param")
|
||||||
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
|
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
|
||||||
installed = Column(Boolean, default=False, comment="plugin already installed")
|
installed = Column(Boolean, default=False, comment="plugin already installed")
|
||||||
|
|
||||||
@@ -61,6 +63,8 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
|||||||
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
|
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
plugin_hubs = session.query(PluginHubEntity)
|
plugin_hubs = session.query(PluginHubEntity)
|
||||||
|
all_count = plugin_hubs.count()
|
||||||
|
|
||||||
if query.id is not None:
|
if query.id is not None:
|
||||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||||
if query.name is not None:
|
if query.name is not None:
|
||||||
@@ -84,6 +88,20 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
|||||||
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
|
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
|
||||||
result = plugin_hubs.all()
|
result = plugin_hubs.all()
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
total_pages = all_count // page_size
|
||||||
|
if all_count % page_size != 0:
|
||||||
|
total_pages += 1
|
||||||
|
|
||||||
|
|
||||||
|
return result, total_pages, all_count
|
||||||
|
|
||||||
|
def get_by_storage_url(self, storage_url):
|
||||||
|
session = self.Session()
|
||||||
|
plugin_hubs = session.query(PluginHubEntity)
|
||||||
|
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
|
||||||
|
result = plugin_hubs.all()
|
||||||
|
session.close()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_by_name(self, name: str) -> PluginHubEntity:
|
def get_by_name(self, name: str) -> PluginHubEntity:
|
||||||
|
@@ -1,39 +1,54 @@
|
|||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import git
|
|
||||||
import os
|
import os
|
||||||
|
import glob
|
||||||
|
import zipfile
|
||||||
|
import requests
|
||||||
|
from pathlib import Path
|
||||||
|
import datetime
|
||||||
|
import shutil
|
||||||
|
from fastapi import UploadFile
|
||||||
|
from typing import Any
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||||
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
|
||||||
from .schema import PluginStorageType
|
from .schema import PluginStorageType
|
||||||
|
from ..plugins_util import scan_plugins, update_from_git
|
||||||
|
|
||||||
logger = logging.getLogger("agent_hub")
|
logger = logging.getLogger("agent_hub")
|
||||||
Default_User = "default"
|
Default_User = "default"
|
||||||
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
|
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
|
||||||
TEMP_PLUGIN_PATH = ""
|
TEMP_PLUGIN_PATH = ""
|
||||||
|
|
||||||
|
|
||||||
class AgentHub:
|
class AgentHub:
|
||||||
def __init__(self, temp_hub_file_path:str = "") -> None:
|
def __init__(self, plugin_dir) -> None:
|
||||||
self.hub_dao = PluginHubDao()
|
self.hub_dao = PluginHubDao()
|
||||||
self.my_lugin_dao = MyPluginDao()
|
self.my_lugin_dao = MyPluginDao()
|
||||||
if temp_hub_file_path:
|
os.makedirs(plugin_dir, exist_ok=True)
|
||||||
self.temp_hub_file_path = temp_hub_file_path
|
self.plugin_dir = plugin_dir
|
||||||
else:
|
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
|
||||||
self.temp_hub_file_path = os.path.join(os.getcwd(), "plugins", "temp")
|
|
||||||
|
|
||||||
def install_plugin(self, plugin_name: str, user_name: str = None):
|
def install_plugin(self, plugin_name: str, user_name: str = None):
|
||||||
logger.info(f"install_plugin {plugin_name}")
|
logger.info(f"install_plugin {plugin_name}")
|
||||||
|
|
||||||
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||||
if plugin_entity:
|
if plugin_entity:
|
||||||
if plugin_entity.storage_channel == PluginStorageType.Git.value:
|
if plugin_entity.storage_channel == PluginStorageType.Git.value:
|
||||||
try:
|
try:
|
||||||
self.__download_from_git(plugin_name, plugin_entity.storage_url)
|
branch_name = None
|
||||||
self.load_plugin(plugin_name)
|
authorization = None
|
||||||
|
if plugin_entity.download_param:
|
||||||
|
download_param = json.loads(plugin_entity.download_param)
|
||||||
|
branch_name = download_param.get("branch_name")
|
||||||
|
authorization = download_param.get("authorization")
|
||||||
|
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
|
||||||
|
|
||||||
# add to my plugins and edit hub status
|
# add to my plugins and edit hub status
|
||||||
plugin_entity.installed = True
|
plugin_entity.installed = True
|
||||||
|
|
||||||
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
my_plugin_entity = self.__build_my_plugin(plugin_entity)
|
||||||
|
my_plugin_entity.file_name = file_name
|
||||||
if user_name:
|
if user_name:
|
||||||
# TODO use user
|
# TODO use user
|
||||||
my_plugin_entity.user_code = ""
|
my_plugin_entity.user_code = ""
|
||||||
@@ -55,9 +70,43 @@ class AgentHub:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
|
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
||||||
|
|
||||||
|
def uninstall_plugin(self, plugin_name, user):
|
||||||
|
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
||||||
|
plugin_entity = self.hub_dao.get_by_name(plugin_name)
|
||||||
|
plugin_entity.installed = False
|
||||||
|
with self.hub_dao.Session() as session:
|
||||||
|
try:
|
||||||
|
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
|
||||||
|
if user:
|
||||||
|
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||||
|
my_plugin_q.delete()
|
||||||
|
session.merge(plugin_entity)
|
||||||
|
session.commit()
|
||||||
|
except:
|
||||||
|
session.rollback()
|
||||||
|
|
||||||
|
# delete package file if not use
|
||||||
|
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
|
||||||
|
have_installed = False
|
||||||
|
for plugin_info in plugin_infos:
|
||||||
|
if plugin_info.installed:
|
||||||
|
have_installed = True
|
||||||
|
break
|
||||||
|
if not have_installed:
|
||||||
|
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
|
||||||
|
files = glob.glob(
|
||||||
|
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
|
||||||
|
)
|
||||||
|
for file in files:
|
||||||
|
os.remove(file)
|
||||||
|
|
||||||
|
def __download_from_git(self, github_repo, branch_name, authorization):
|
||||||
|
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
|
||||||
|
|
||||||
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
|
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
|
||||||
my_plugin_entity = MyPluginEntity()
|
my_plugin_entity = MyPluginEntity()
|
||||||
my_plugin_entity.name = hub_plugin.name
|
my_plugin_entity.name = hub_plugin.name
|
||||||
@@ -65,29 +114,68 @@ class AgentHub:
|
|||||||
my_plugin_entity.version = hub_plugin.version
|
my_plugin_entity.version = hub_plugin.version
|
||||||
return my_plugin_entity
|
return my_plugin_entity
|
||||||
|
|
||||||
def __fetch_from_git(self):
|
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
|
||||||
logger.info("fetch plugins from git to local path:{}", self.temp_hub_file_path)
|
logger.info("refresh_hub_by_git start!")
|
||||||
os.makedirs(self.temp_hub_file_path, exist_ok=True)
|
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
|
||||||
repo = git.Repo(self.temp_hub_file_path)
|
git_plugins = scan_plugins(self.temp_hub_file_path)
|
||||||
if repo.is_repo():
|
for git_plugin in git_plugins:
|
||||||
repo.remotes.origin.pull()
|
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
|
||||||
else:
|
if old_hub_info:
|
||||||
git.Repo.clone_from(DEFAULT_PLUGIN_REPO, self.temp_hub_file_path)
|
plugin_hub_info = old_hub_info
|
||||||
|
else:
|
||||||
|
plugin_hub_info = PluginHubEntity()
|
||||||
|
plugin_hub_info.type = ""
|
||||||
|
plugin_hub_info.storage_channel = PluginStorageType.Git.value
|
||||||
|
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
|
||||||
|
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
|
||||||
|
plugin_hub_info.email = getattr(git_plugin, '_email', '')
|
||||||
|
plugin_hub_info.download_param = json.dumps({
|
||||||
|
branch_name: branch_name,
|
||||||
|
authorization: authorization
|
||||||
|
})
|
||||||
|
plugin_hub_info.installed = False
|
||||||
|
|
||||||
# if repo.head.is_valid():
|
plugin_hub_info.name = git_plugin._name
|
||||||
# clone succ, fetch plugins info
|
plugin_hub_info.version = git_plugin._version
|
||||||
|
plugin_hub_info.description = git_plugin._description
|
||||||
|
self.hub_dao.update(plugin_hub_info)
|
||||||
|
|
||||||
|
def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
|
||||||
|
|
||||||
|
# We can not move temp file in windows system when we open file in context of `with`
|
||||||
|
file_path = os.path.join(self.plugin_dir, doc_file.filename)
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
os.remove(file_path)
|
||||||
|
tmp_fd, tmp_path = tempfile.mkstemp(
|
||||||
|
dir=os.path.join(self.plugin_dir)
|
||||||
|
)
|
||||||
|
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||||
|
tmp.write(await doc_file.read())
|
||||||
|
shutil.move(
|
||||||
|
tmp_path,
|
||||||
|
os.path.join(self.plugin_dir, doc_file.filename),
|
||||||
|
)
|
||||||
|
|
||||||
|
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
|
||||||
|
for my_plugin in my_plugins:
|
||||||
|
my_plugin_entiy = MyPluginEntity()
|
||||||
|
|
||||||
|
my_plugin_entiy.name = my_plugin._name
|
||||||
|
my_plugin_entiy.version = my_plugin._version
|
||||||
|
my_plugin_entiy.type = "Personal"
|
||||||
|
my_plugin_entiy.user_code = user
|
||||||
|
my_plugin_entiy.user_name = user
|
||||||
|
my_plugin_entiy.tenant = ""
|
||||||
|
my_plugin_entiy.file_name = doc_file.filename
|
||||||
|
|
||||||
|
self.my_lugin_dao.update(my_plugin_entiy)
|
||||||
|
|
||||||
|
|
||||||
def upload_plugin_in_hub(self, name: str, path: str):
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __download_from_git(self, plugin_name, url):
|
def reload_my_plugins(self):
|
||||||
pass
|
logger.info(f"load_plugins start!")
|
||||||
|
return scan_plugins(self.plugin_dir)
|
||||||
def load_plugin(self, plugin_name):
|
|
||||||
logger.info(f"load_plugin:{plugin_name}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_my_plugin(self, user: str):
|
def get_my_plugin(self, user: str):
|
||||||
logger.info(f"get_my_plugin:{user}")
|
logger.info(f"get_my_plugin:{user}")
|
||||||
@@ -95,7 +183,3 @@ class AgentHub:
|
|||||||
user = Default_User
|
user = Default_User
|
||||||
return self.my_lugin_dao.get_by_user(user)
|
return self.my_lugin_dao.get_by_user(user)
|
||||||
|
|
||||||
def uninstall_plugin(self, plugin_name, user):
|
|
||||||
logger.info(f"uninstall_plugin:{plugin_name},{user}")
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
30
pilot/base_modules/agent/model.py
Normal file
30
pilot/base_modules/agent/model.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from typing import TypedDict, Optional, Dict, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import TypeVar, Generic, Any
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
class PagenationFilter(Generic[T]):
|
||||||
|
page_index: int = 1
|
||||||
|
page_size: int = 20
|
||||||
|
filter: T = None
|
||||||
|
|
||||||
|
class PagenationResult(Generic[T]):
|
||||||
|
page_index: int = 1
|
||||||
|
page_size: int = 20
|
||||||
|
total_page: int = 0
|
||||||
|
total_row_count: int = 0
|
||||||
|
datas: List[T] = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PluginHubParam:
|
||||||
|
channel: str = Field(..., description="Plugin storage channel")
|
||||||
|
url: str = Field(..., description="Plugin storage url")
|
||||||
|
|
||||||
|
branch: str = Field(..., description="When the storage channel is github, use to specify the branch", nullable=True)
|
||||||
|
authorization: str = Field(..., description="github download authorization", nullable=True)
|
||||||
|
|
||||||
|
|
0
pilot/base_modules/agent/module.py
Normal file
0
pilot/base_modules/agent/module.py
Normal file
2
pilot/base_modules/agent/plugins_loader.py
Normal file
2
pilot/base_modules/agent/plugins_loader.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
|
||||||
|
class PluginLoader():
|
@@ -5,6 +5,7 @@ import os
|
|||||||
import glob
|
import glob
|
||||||
import zipfile
|
import zipfile
|
||||||
import requests
|
import requests
|
||||||
|
import git
|
||||||
import threading
|
import threading
|
||||||
import datetime
|
import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -20,7 +21,6 @@ from pilot.configs.model_config import PLUGINS_DIR
|
|||||||
from pilot.logs import logger
|
from pilot.logs import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Loader zip plugin file. Native support Auto_gpt_plugin
|
Loader zip plugin file. Native support Auto_gpt_plugin
|
||||||
@@ -106,7 +106,7 @@ def load_native_plugins(cfg: Config):
|
|||||||
with open(file_name, "wb") as f:
|
with open(file_name, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
print("save file")
|
print("save file")
|
||||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
cfg.set_plugins(scan_plugins(cfg.debug_mode))
|
||||||
else:
|
else:
|
||||||
print("get file faild,response code:", response.status_code)
|
print("get file faild,response code:", response.status_code)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -116,7 +116,30 @@ def load_native_plugins(cfg: Config):
|
|||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
|
||||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
|
||||||
|
logger.info(f"__scan_plugin_file:{file_path},{debug}")
|
||||||
|
loaded_plugins = []
|
||||||
|
if moduleList := inspect_zip_for_modules(str(file_path), debug):
|
||||||
|
for module in moduleList:
|
||||||
|
plugin = Path(plugin)
|
||||||
|
module = Path(module)
|
||||||
|
logger.debug(f"Plugin: {plugin} Module: {module}")
|
||||||
|
zipped_package = zipimporter(str(plugin))
|
||||||
|
zipped_module = zipped_package.load_module(str(module.parent))
|
||||||
|
for key in dir(zipped_module):
|
||||||
|
if key.startswith("__"):
|
||||||
|
continue
|
||||||
|
a_module = getattr(zipped_module, key)
|
||||||
|
a_keys = dir(a_module)
|
||||||
|
if (
|
||||||
|
"_abc_impl" in a_keys
|
||||||
|
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||||
|
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||||
|
):
|
||||||
|
loaded_plugins.append(a_module())
|
||||||
|
return loaded_plugins
|
||||||
|
|
||||||
|
def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||||
"""Scan the plugins directory for plugins and loads them.
|
"""Scan the plugins directory for plugins and loads them.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -126,31 +149,16 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
|
|||||||
Returns:
|
Returns:
|
||||||
List[Tuple[str, Path]]: List of plugins.
|
List[Tuple[str, Path]]: List of plugins.
|
||||||
"""
|
"""
|
||||||
loaded_plugins = []
|
|
||||||
current_dir = os.getcwd()
|
|
||||||
print(current_dir)
|
|
||||||
# Generic plugins
|
|
||||||
plugins_path_path = Path(PLUGINS_DIR)
|
|
||||||
|
|
||||||
for plugin in plugins_path_path.glob("*.zip"):
|
loaded_plugins = []
|
||||||
if moduleList := inspect_zip_for_modules(str(plugin), debug):
|
# Generic plugins
|
||||||
for module in moduleList:
|
plugins_path = Path(plugins_file_path)
|
||||||
plugin = Path(plugin)
|
if file_name:
|
||||||
module = Path(module)
|
plugin_path = Path(plugins_path, file_name)
|
||||||
logger.debug(f"Plugin: {plugin} Module: {module}")
|
loaded_plugins = __scan_plugin_file(plugin_path)
|
||||||
zipped_package = zipimporter(str(plugin))
|
else:
|
||||||
zipped_module = zipped_package.load_module(str(module.parent))
|
for plugin_path in plugins_path.glob("*.zip"):
|
||||||
for key in dir(zipped_module):
|
loaded_plugins.extend(__scan_plugin_file(plugin_path))
|
||||||
if key.startswith("__"):
|
|
||||||
continue
|
|
||||||
a_module = getattr(zipped_module, key)
|
|
||||||
a_keys = dir(a_module)
|
|
||||||
if (
|
|
||||||
"_abc_impl" in a_keys
|
|
||||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
|
||||||
# and denylist_allowlist_check(a_module.__name__, cfg)
|
|
||||||
):
|
|
||||||
loaded_plugins.append(a_module())
|
|
||||||
|
|
||||||
if loaded_plugins:
|
if loaded_plugins:
|
||||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||||
@@ -183,5 +191,55 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
|||||||
return ack.lower() == cfg.authorise_key
|
return ack.lower() == cfg.authorise_key
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
|
||||||
print(inspect_zip_for_modules("/Users/tuyang.yhj/Downloads/DB-GPT-Plugins-main (1).zip"))
|
authorization: str = "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"):
|
||||||
|
os.makedirs(download_path, exist_ok=True)
|
||||||
|
if github_repo:
|
||||||
|
if github_repo.index("github.com") <= 0:
|
||||||
|
raise ValueError("Not a correct Github repository address!" + github_repo)
|
||||||
|
github_repo = github_repo.replace(".git", "")
|
||||||
|
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
|
||||||
|
plugin_repo_name = github_repo.strip('/').split('/')[-1]
|
||||||
|
else:
|
||||||
|
url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
|
||||||
|
plugin_repo_name = "DB-GPT-Plugins"
|
||||||
|
try:
|
||||||
|
session = requests.Session()
|
||||||
|
response = session.get(
|
||||||
|
url,
|
||||||
|
headers={"Authorization": authorization},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
plugins_path_path = Path(download_path)
|
||||||
|
files = glob.glob(
|
||||||
|
os.path.join(plugins_path_path, f"{plugin_repo_name}*")
|
||||||
|
)
|
||||||
|
for file in files:
|
||||||
|
os.remove(file)
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||||
|
file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
|
||||||
|
print(file_name)
|
||||||
|
with open(file_name, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
return plugin_repo_name
|
||||||
|
else:
|
||||||
|
logger.error("update plugins faild,response code:", response.status_code)
|
||||||
|
raise ValueError("download plugin faild!" + response.status_code)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("update plugins from git exception!" + str(e))
|
||||||
|
raise ValueError("download plugin exception!", e)
|
||||||
|
|
||||||
|
|
||||||
|
def __fetch_from_git(local_path, git_url):
|
||||||
|
logger.info("fetch plugins from git to local path:{}", local_path)
|
||||||
|
os.makedirs(local_path, exist_ok=True)
|
||||||
|
repo = git.Repo(local_path)
|
||||||
|
if repo.is_repo():
|
||||||
|
repo.remotes.origin.pull()
|
||||||
|
else:
|
||||||
|
git.Repo.clone_from(git_url, local_path)
|
||||||
|
|
||||||
|
# if repo.head.is_valid():
|
||||||
|
# clone succ, fetch plugins info
|
@@ -56,6 +56,13 @@ class ChatScene(Enum):
|
|||||||
param_types=["Plugin Select"],
|
param_types=["Plugin Select"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ChatAgent = Scene(
|
||||||
|
code="chat_agent",
|
||||||
|
name="Agent Chat",
|
||||||
|
describe="Use tools through dialogue to accomplish your goals.",
|
||||||
|
param_types=["Plugin Select"],
|
||||||
|
)
|
||||||
|
|
||||||
InnerChatDBSummary = Scene(
|
InnerChatDBSummary = Scene(
|
||||||
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
|
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
|
||||||
)
|
)
|
||||||
|
0
pilot/scene/chat_agent/__init__.py
Normal file
0
pilot/scene/chat_agent/__init__.py
Normal file
72
pilot/scene/chat_agent/chat.py
Normal file
72
pilot/scene/chat_agent/chat.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from pilot.scene.base_chat import BaseChat
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.base_modules.agent.commands.command import execute_command
|
||||||
|
from pilot.base_modules.agent import PluginPromptGenerator
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatWithPlugin(BaseChat):
|
||||||
|
chat_scene: str = ChatScene.ChatAgent.value()
|
||||||
|
plugins_prompt_generator: PluginPromptGenerator
|
||||||
|
select_plugin: str = None
|
||||||
|
|
||||||
|
def __init__(self, chat_param: Dict):
|
||||||
|
self.plugin_selector = chat_param.select_param
|
||||||
|
chat_param["chat_mode"] = ChatScene.ChatAgent
|
||||||
|
super().__init__(chat_param=chat_param)
|
||||||
|
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||||
|
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||||
|
# 加载插件中可用命令
|
||||||
|
self.select_plugin = self.plugin_selector
|
||||||
|
if self.select_plugin:
|
||||||
|
for plugin in CFG.plugins:
|
||||||
|
if plugin._name == self.plugin_selector:
|
||||||
|
if not plugin.can_handle_post_prompt():
|
||||||
|
continue
|
||||||
|
self.plugins_prompt_generator = plugin.post_prompt(
|
||||||
|
self.plugins_prompt_generator
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
for plugin in CFG.plugins:
|
||||||
|
if not plugin.can_handle_post_prompt():
|
||||||
|
continue
|
||||||
|
self.plugins_prompt_generator = plugin.post_prompt(
|
||||||
|
self.plugins_prompt_generator
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_input_values(self):
|
||||||
|
input_values = {
|
||||||
|
"input": self.current_user_input,
|
||||||
|
"constraints": self.__list_to_prompt_str(
|
||||||
|
list(self.plugins_prompt_generator.constraints)
|
||||||
|
),
|
||||||
|
"commands_infos": self.plugins_prompt_generator.generate_commands_string(),
|
||||||
|
}
|
||||||
|
return input_values
|
||||||
|
|
||||||
|
def do_action(self, prompt_response):
|
||||||
|
print(f"do_action:{prompt_response}")
|
||||||
|
## plugin command run
|
||||||
|
return execute_command(
|
||||||
|
str(prompt_response.command.get("name")),
|
||||||
|
prompt_response.command.get("args", {}),
|
||||||
|
self.plugins_prompt_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
def chat_show(self):
|
||||||
|
super().chat_show()
|
||||||
|
|
||||||
|
def __list_to_prompt_str(self, list: List) -> str:
|
||||||
|
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||||
|
|
||||||
|
def generate(self, p) -> str:
|
||||||
|
return super().generate(p)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_type(self) -> str:
|
||||||
|
return ChatScene.ChatAgent.value
|
23
pilot/scene/chat_agent/example.py
Normal file
23
pilot/scene/chat_agent/example.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from pilot.prompts.example_base import ExampleSelector
|
||||||
|
|
||||||
|
## Two examples are defined by default
|
||||||
|
EXAMPLES = [
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"type": "human", "data": {"content": "查询xxx", "example": True}},
|
||||||
|
{
|
||||||
|
"type": "ai",
|
||||||
|
"data": {
|
||||||
|
"content": """{
|
||||||
|
\"thoughts\": \"thought text\",
|
||||||
|
\"speak\": \"thoughts summary to say to user\",
|
||||||
|
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
|
||||||
|
}""",
|
||||||
|
"example": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
|
46
pilot/scene/chat_agent/out_parser.py
Normal file
46
pilot/scene/chat_agent/out_parser.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import json
|
||||||
|
from typing import Dict, NamedTuple
|
||||||
|
from pilot.utils import build_logger
|
||||||
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
|
from pilot.configs.model_config import LOGDIR
|
||||||
|
|
||||||
|
|
||||||
|
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginAction(NamedTuple):
|
||||||
|
command: Dict
|
||||||
|
speak: str = ""
|
||||||
|
thoughts: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PluginChatOutputParser(BaseOutputParser):
|
||||||
|
def parse_prompt_response(self, model_out_text) -> T:
|
||||||
|
clean_json_str = super().parse_prompt_response(model_out_text)
|
||||||
|
print(clean_json_str)
|
||||||
|
if not clean_json_str:
|
||||||
|
raise ValueError("model server response not have json!")
|
||||||
|
try:
|
||||||
|
response = json.loads(clean_json_str)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError("model server out not fllow the prompt!")
|
||||||
|
|
||||||
|
speak = ""
|
||||||
|
thoughts = ""
|
||||||
|
for key in sorted(response):
|
||||||
|
if key.strip() == "command":
|
||||||
|
command = response[key]
|
||||||
|
if key.strip() == "thoughts":
|
||||||
|
thoughts = response[key]
|
||||||
|
if key.strip() == "speak":
|
||||||
|
speak = response[key]
|
||||||
|
return PluginAction(command, speak, thoughts)
|
||||||
|
|
||||||
|
def parse_view_response(self, speak, data) -> str:
|
||||||
|
### tool out data to table view
|
||||||
|
print(f"parse_view_response:{speak},{str(data)}")
|
||||||
|
view_text = f"##### {speak}" + "\n" + str(data)
|
||||||
|
return view_text
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
pass
|
78
pilot/scene/chat_agent/prompt.py
Normal file
78
pilot/scene/chat_agent/prompt.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import json
|
||||||
|
from pilot.prompts.prompt_new import PromptTemplate
|
||||||
|
from pilot.configs.config import Config
|
||||||
|
from pilot.scene.base import ChatScene
|
||||||
|
from pilot.common.schema import SeparatorStyle, ExampleType
|
||||||
|
|
||||||
|
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
|
||||||
|
from pilot.scene.chat_execution.example import plugin_example
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant."
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_EN = """
|
||||||
|
You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals.
|
||||||
|
Tool list:
|
||||||
|
{tool_list}
|
||||||
|
Constraint:
|
||||||
|
1. After selecting an available tool, please ensure that the output results include the following parts to use the tool:
|
||||||
|
<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
|
||||||
|
2. If you cannot analyze the exact tool for the problem, you can consider using the search engine tool among the tools first.
|
||||||
|
3. Parameter content may need to be inferred based on the user's goals, not just extracted from text
|
||||||
|
4. If you cannot find a suitable tool, please answer Unable to complete the goal.
|
||||||
|
{expand_constraints}
|
||||||
|
User goals:
|
||||||
|
{user_goal}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手!"
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_ZH = """
|
||||||
|
请一步步思考,如何在满足下面约束条件的前提下,回答或解决用户问题或目标。
|
||||||
|
工具列表:
|
||||||
|
{tool_list}
|
||||||
|
约束条件:
|
||||||
|
1. 找到可用的工具后,请确保输出结果包含以下内容用来使用工具:<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
|
||||||
|
2.任务重可以使用多个工具,上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
|
||||||
|
3.如果有多个工具被使用,后续工具需要第一个工具的结果作为参数的, 使用如下文本来替代参数值:<api1-result>
|
||||||
|
4.如果对于问题无法理解和解决,可以考虑优先使用工具中的搜索引擎工具
|
||||||
|
5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
|
||||||
|
6.如果中无法找到合适的工具,请回答无法完成目标。
|
||||||
|
7.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中
|
||||||
|
{expand_constraints}
|
||||||
|
用户目标:
|
||||||
|
{user_goal}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = (
|
||||||
|
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PROMPT_SCENE_DEFINE=(
|
||||||
|
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
|
RESPONSE_FORMAT = None
|
||||||
|
|
||||||
|
|
||||||
|
EXAMPLE_TYPE = ExampleType.ONE_SHOT
|
||||||
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
### Whether the model service is streaming output
|
||||||
|
PROMPT_NEED_STREAM_OUT = True
|
||||||
|
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template_scene=ChatScene.ChatAgent.value(),
|
||||||
|
input_variables=["tool_list", "expand_constraints", "user_goal"],
|
||||||
|
response_format=None,
|
||||||
|
template_define=_PROMPT_SCENE_DEFINE,
|
||||||
|
template=_DEFAULT_TEMPLATE,
|
||||||
|
stream_out=PROMPT_NEED_STREAM_OUT,
|
||||||
|
output_parser=PluginChatOutputParser(
|
||||||
|
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
|
||||||
|
),
|
||||||
|
# example_selector=plugin_example,
|
||||||
|
)
|
||||||
|
|
||||||
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
@@ -1 +0,0 @@
|
|||||||
|
|
@@ -6,6 +6,7 @@ from typing import Optional, Any
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
from pilot.configs.model_config import PLUGINS_DIR
|
||||||
from pilot.componet import SystemApp
|
from pilot.componet import SystemApp
|
||||||
from pilot.utils.parameter_utils import BaseParameters
|
from pilot.utils.parameter_utils import BaseParameters
|
||||||
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
|
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
|
||||||
@@ -30,7 +31,7 @@ def async_db_summery(system_app: SystemApp):
|
|||||||
def server_init(args, system_app: SystemApp):
|
def server_init(args, system_app: SystemApp):
|
||||||
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
|
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
|
||||||
|
|
||||||
from pilot.base_modules.agent.plugins import scan_plugins
|
from pilot.base_modules.agent.plugins_util import scan_plugins
|
||||||
|
|
||||||
# logger.info(f"args: {args}")
|
# logger.info(f"args: {args}")
|
||||||
|
|
||||||
@@ -43,7 +44,7 @@ def server_init(args, system_app: SystemApp):
|
|||||||
# load_native_plugins(cfg)
|
# load_native_plugins(cfg)
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
|
cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
|
||||||
|
|
||||||
# Loader plugins and commands
|
# Loader plugins and commands
|
||||||
command_categories = [
|
command_categories = [
|
||||||
@@ -126,3 +127,4 @@ class WebWerverParameters(BaseParameters):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
|
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user