diff --git a/pilot/base_modules/agent/agent.py b/pilot/base_modules/agent/agent.py
deleted file mode 100644
index d11f21ab5..000000000
--- a/pilot/base_modules/agent/agent.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from abc import ABC, abstractmethod
-
-
-
-
-
-
-class AgentFacade(ABC):
- def __init__(self) -> None:
- self.model = None
-
-
-
diff --git a/pilot/base_modules/agent/commands/command.py b/pilot/base_modules/agent/commands/command.py
index b4cb0c0f6..0949ddb82 100644
--- a/pilot/base_modules/agent/commands/command.py
+++ b/pilot/base_modules/agent/commands/command.py
@@ -27,7 +27,6 @@ def execute_ai_response_json(
user_input: str = None,
) -> str:
"""
-
Args:
command_registry:
ai_response:
@@ -65,6 +64,8 @@ def execute_ai_response_json(
return result
+
+
def execute_command(
command_name: str,
arguments,
diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py
index cbf0e3b7b..f6ecd2174 100644
--- a/pilot/base_modules/agent/controller.py
+++ b/pilot/base_modules/agent/controller.py
@@ -3,6 +3,8 @@ import time
from fastapi import (
APIRouter,
Body,
+ UploadFile,
+ File,
)
from typing import List
@@ -13,14 +15,77 @@ from pilot.openapi.api_view_model import (
Result,
)
-
+from .model import PluginHubParam, PagenationFilter, PagenationResult
+from .hub.agent_hub import AgentHub
+from .db.plugin_hub_db import PluginHubEntity
+from .db.my_plugin_db import MyPluginEntity
+from pilot.configs.model_config import PLUGINS_DIR
router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
-@router.get("/v1/mange/agent/list", response_model=Result[str])
-async def get_agent_list():
- logger.info(f"get_agent_list!")
-
+@router.post("/v1/agent/hub/update", response_model=Result[str])
+async def agent_hub_update(update_param: PluginHubParam = Body()):
+ logger.info(f"agent_hub_update:{update_param.__dict__}")
+ agent_hub = AgentHub(PLUGINS_DIR)
+ agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
return Result.succ(None)
+
+
+@router.post("/v1/agent/query", response_model=Result[str])
+async def get_agent_list(filter: PagenationFilter[PluginHubEntity] = Body()):
+ logger.info(f"get_agent_list:{json.dumps(filter)}")
+ agent_hub = AgentHub(PLUGINS_DIR)
+ datas, total_pages, total_count = agent_hub.hub_dao.list(filter.filter, filter.page_index, filter.page_size)
+ result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
+ result.page_index = filter.page_index
+ result.page_size = filter.page_size
+ result.total_page = total_pages
+ result.total_row_count = total_count
+ result.datas = datas
+ return Result.succ(result)
+
+@router.post("/v1/agent/my", response_model=Result[str])
+async def my_agents(user:str= None):
+ logger.info(f"my_agents:{json.dumps(my_agents)}")
+ agent_hub = AgentHub(PLUGINS_DIR)
+ return Result.succ(agent_hub.get_my_plugin(user))
+
+
+@router.post("/v1/agent/install", response_model=Result[str])
+async def agent_install(plugin_name:str, user: str = None):
+ logger.info(f"agent_install:{plugin_name},{user}")
+ try:
+ agent_hub = AgentHub(PLUGINS_DIR)
+ agent_hub.install_plugin(plugin_name, user)
+ return Result.succ(None)
+ except Exception as e:
+ logger.error("Plugin Install Error!", e)
+ return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
+
+
+
+@router.post("/v1/agent/uninstall", response_model=Result[str])
+async def agent_uninstall(plugin_name:str, user: str = None):
+ logger.info(f"agent_uninstall:{plugin_name},{user}")
+ try:
+ agent_hub = AgentHub(PLUGINS_DIR)
+ agent_hub.uninstall_plugin(plugin_name, user)
+ return Result.succ(None)
+ except Exception as e:
+ logger.error("Plugin Uninstall Error!", e)
+ return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
+
+
+@router.post("/v1/personal/agent/upload", response_model=Result[str])
+async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
+ logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
+ try:
+ agent_hub = AgentHub(PLUGINS_DIR)
+ agent_hub.upload_my_plugin(doc_file, user)
+ return Result.succ(None)
+ except Exception as e:
+ logger.error("Upload Personal Plugin Error!", e)
+ return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")
+
diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py
index 2e2a73ed9..545866854 100644
--- a/pilot/base_modules/agent/db/my_plugin_db.py
+++ b/pilot/base_modules/agent/db/my_plugin_db.py
@@ -18,6 +18,7 @@ class MyPluginEntity(Base):
user_code = Column(String, nullable=True, comment="user code")
user_name = Column(String, nullable=True, comment="user name")
name = Column(String, unique=True, nullable=False, comment="plugin name")
+ file_name = Column(String, nullable=False, comment="plugin package file name")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
@@ -74,6 +75,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
session = self.Session()
my_plugins = session.query(MyPluginEntity)
+ all_count = my_plugins.count()
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
@@ -101,7 +103,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
result = my_plugins.all()
session.close()
- return result
+ total_pages = all_count // page_size
+ if all_count % page_size != 0:
+ total_pages += 1
+
+
+ return result, total_pages, all_count
+
def count(self, query: MyPluginEntity):
session = self.Session()
diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py
index 1dfb2363d..b579f5071 100644
--- a/pilot/base_modules/agent/db/plugin_hub_db.py
+++ b/pilot/base_modules/agent/db/plugin_hub_db.py
@@ -13,12 +13,14 @@ class PluginHubEntity(Base):
__tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
name = Column(String, unique=True, nullable=False, comment="plugin name")
+ description = Column(String, nullable=False, comment="plugin description")
author = Column(String, nullable=True, comment="plugin author")
email = Column(String, nullable=True, comment="plugin author email")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
storage_channel = Column(String, comment="plugin storage channel")
storage_url = Column(String, comment="plugin download url")
+ download_param = Column(String, comment="plugin download param")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
installed = Column(Boolean, default=False, comment="plugin already installed")
@@ -61,6 +63,8 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
session = self.Session()
plugin_hubs = session.query(PluginHubEntity)
+ all_count = plugin_hubs.count()
+
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
@@ -84,6 +88,20 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all()
session.close()
+
+ total_pages = all_count // page_size
+ if all_count % page_size != 0:
+ total_pages += 1
+
+
+ return result, total_pages, all_count
+
+ def get_by_storage_url(self, storage_url):
+ session = self.Session()
+ plugin_hubs = session.query(PluginHubEntity)
+ plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
+ result = plugin_hubs.all()
+ session.close()
return result
def get_by_name(self, name: str) -> PluginHubEntity:
diff --git a/pilot/base_modules/agent/hub/agent_hub.py b/pilot/base_modules/agent/hub/agent_hub.py
index 55bc1c7a6..2d2daaf50 100644
--- a/pilot/base_modules/agent/hub/agent_hub.py
+++ b/pilot/base_modules/agent/hub/agent_hub.py
@@ -1,39 +1,54 @@
+import json
import logging
-import git
import os
+import glob
+import zipfile
+import requests
+from pathlib import Path
+import datetime
+import shutil
+from fastapi import UploadFile
+from typing import Any
+import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from .schema import PluginStorageType
+from ..plugins_util import scan_plugins, update_from_git
logger = logging.getLogger("agent_hub")
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
+
class AgentHub:
- def __init__(self, temp_hub_file_path:str = "") -> None:
+ def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_lugin_dao = MyPluginDao()
- if temp_hub_file_path:
- self.temp_hub_file_path = temp_hub_file_path
- else:
- self.temp_hub_file_path = os.path.join(os.getcwd(), "plugins", "temp")
+ os.makedirs(plugin_dir, exist_ok=True)
+ self.plugin_dir = plugin_dir
+ self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
-
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
- self.__download_from_git(plugin_name, plugin_entity.storage_url)
- self.load_plugin(plugin_name)
+ branch_name = None
+ authorization = None
+ if plugin_entity.download_param:
+ download_param = json.loads(plugin_entity.download_param)
+ branch_name = download_param.get("branch_name")
+ authorization = download_param.get("authorization")
+ file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
# add to my plugins and edit hub status
plugin_entity.installed = True
my_plugin_entity = self.__build_my_plugin(plugin_entity)
+ my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = ""
@@ -55,9 +70,43 @@ class AgentHub:
else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
+
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
+ def uninstall_plugin(self, plugin_name, user):
+ logger.info(f"uninstall_plugin:{plugin_name},{user}")
+ plugin_entity = self.hub_dao.get_by_name(plugin_name)
+ plugin_entity.installed = False
+ with self.hub_dao.Session() as session:
+ try:
+ my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
+ if user:
+ my_plugin_q.filter(MyPluginEntity.user_code == user)
+ my_plugin_q.delete()
+ session.merge(plugin_entity)
+ session.commit()
+ except:
+ session.rollback()
+
+ # delete package file if not use
+ plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
+ have_installed = False
+ for plugin_info in plugin_infos:
+ if plugin_info.installed:
+ have_installed = True
+ break
+ if not have_installed:
+ plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
+ files = glob.glob(
+ os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
+ )
+ for file in files:
+ os.remove(file)
+
+ def __download_from_git(self, github_repo, branch_name, authorization):
+ return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
+
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
@@ -65,29 +114,68 @@ class AgentHub:
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
- def __fetch_from_git(self):
- logger.info("fetch plugins from git to local path:{}", self.temp_hub_file_path)
- os.makedirs(self.temp_hub_file_path, exist_ok=True)
- repo = git.Repo(self.temp_hub_file_path)
- if repo.is_repo():
- repo.remotes.origin.pull()
- else:
- git.Repo.clone_from(DEFAULT_PLUGIN_REPO, self.temp_hub_file_path)
+ def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
+ logger.info("refresh_hub_by_git start!")
+ update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
+ git_plugins = scan_plugins(self.temp_hub_file_path)
+ for git_plugin in git_plugins:
+ old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
+ if old_hub_info:
+ plugin_hub_info = old_hub_info
+ else:
+ plugin_hub_info = PluginHubEntity()
+ plugin_hub_info.type = ""
+ plugin_hub_info.storage_channel = PluginStorageType.Git.value
+ plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
+ plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
+ plugin_hub_info.email = getattr(git_plugin, '_email', '')
+ plugin_hub_info.download_param = json.dumps({
+ branch_name: branch_name,
+ authorization: authorization
+ })
+ plugin_hub_info.installed = False
- # if repo.head.is_valid():
- # clone succ, fetch plugins info
+ plugin_hub_info.name = git_plugin._name
+ plugin_hub_info.version = git_plugin._version
+ plugin_hub_info.description = git_plugin._description
+ self.hub_dao.update(plugin_hub_info)
+
+ def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
+
+ # We can not move temp file in windows system when we open file in context of `with`
+ file_path = os.path.join(self.plugin_dir, doc_file.filename)
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ tmp_fd, tmp_path = tempfile.mkstemp(
+ dir=os.path.join(self.plugin_dir)
+ )
+ with os.fdopen(tmp_fd, "wb") as tmp:
+ tmp.write(await doc_file.read())
+ shutil.move(
+ tmp_path,
+ os.path.join(self.plugin_dir, doc_file.filename),
+ )
+
+ my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
+ for my_plugin in my_plugins:
+ my_plugin_entiy = MyPluginEntity()
+
+ my_plugin_entiy.name = my_plugin._name
+ my_plugin_entiy.version = my_plugin._version
+ my_plugin_entiy.type = "Personal"
+ my_plugin_entiy.user_code = user
+ my_plugin_entiy.user_name = user
+ my_plugin_entiy.tenant = ""
+ my_plugin_entiy.file_name = doc_file.filename
+
+ self.my_lugin_dao.update(my_plugin_entiy)
- def upload_plugin_in_hub(self, name: str, path: str):
- pass
- def __download_from_git(self, plugin_name, url):
- pass
-
- def load_plugin(self, plugin_name):
- logger.info(f"load_plugin:{plugin_name}")
- pass
+ def reload_my_plugins(self):
+ logger.info(f"load_plugins start!")
+ return scan_plugins(self.plugin_dir)
def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}")
@@ -95,7 +183,3 @@ class AgentHub:
user = Default_User
return self.my_lugin_dao.get_by_user(user)
- def uninstall_plugin(self, plugin_name, user):
- logger.info(f"uninstall_plugin:{plugin_name},{user}")
-
- pass
diff --git a/pilot/base_modules/agent/model.py b/pilot/base_modules/agent/model.py
new file mode 100644
index 000000000..2d268e138
--- /dev/null
+++ b/pilot/base_modules/agent/model.py
@@ -0,0 +1,30 @@
+from typing import TypedDict, Optional, Dict, List
+from dataclasses import dataclass
+from pydantic import BaseModel, Field
+from typing import TypeVar, Generic, Any
+
+T = TypeVar('T')
+
+class PagenationFilter(Generic[T]):
+ page_index: int = 1
+ page_size: int = 20
+ filter: T = None
+
+class PagenationResult(Generic[T]):
+ page_index: int = 1
+ page_size: int = 20
+ total_page: int = 0
+ total_row_count: int = 0
+ datas: List[T] = []
+
+
+
+@dataclass
+class PluginHubParam:
+ channel: str = Field(..., description="Plugin storage channel")
+ url: str = Field(..., description="Plugin storage url")
+
+ branch: str = Field(..., description="When the storage channel is github, use to specify the branch", nullable=True)
+ authorization: str = Field(..., description="github download authorization", nullable=True)
+
+
diff --git a/pilot/base_modules/agent/module.py b/pilot/base_modules/agent/module.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/base_modules/agent/plugins_loader.py b/pilot/base_modules/agent/plugins_loader.py
new file mode 100644
index 000000000..2d5c55a8e
--- /dev/null
+++ b/pilot/base_modules/agent/plugins_loader.py
@@ -0,0 +1,2 @@
+
+class PluginLoader():
diff --git a/pilot/base_modules/agent/plugins.py b/pilot/base_modules/agent/plugins_util.py
similarity index 60%
rename from pilot/base_modules/agent/plugins.py
rename to pilot/base_modules/agent/plugins_util.py
index f1734e43b..5a4f86229 100644
--- a/pilot/base_modules/agent/plugins.py
+++ b/pilot/base_modules/agent/plugins_util.py
@@ -5,6 +5,7 @@ import os
import glob
import zipfile
import requests
+import git
import threading
import datetime
from pathlib import Path
@@ -20,7 +21,6 @@ from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
-
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
"""
Loader zip plugin file. Native support Auto_gpt_plugin
@@ -106,7 +106,7 @@ def load_native_plugins(cfg: Config):
with open(file_name, "wb") as f:
f.write(response.content)
print("save file")
- cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
+ cfg.set_plugins(scan_plugins(cfg.debug_mode))
else:
print("get file faild,response code:", response.status_code)
except Exception as e:
@@ -116,7 +116,30 @@ def load_native_plugins(cfg: Config):
t.start()
-def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
+def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
+ logger.info(f"__scan_plugin_file:{file_path},{debug}")
+ loaded_plugins = []
+ if moduleList := inspect_zip_for_modules(str(file_path), debug):
+ for module in moduleList:
+ plugin = Path(plugin)
+ module = Path(module)
+ logger.debug(f"Plugin: {plugin} Module: {module}")
+ zipped_package = zipimporter(str(plugin))
+ zipped_module = zipped_package.load_module(str(module.parent))
+ for key in dir(zipped_module):
+ if key.startswith("__"):
+ continue
+ a_module = getattr(zipped_module, key)
+ a_keys = dir(a_module)
+ if (
+ "_abc_impl" in a_keys
+ and a_module.__name__ != "AutoGPTPluginTemplate"
+ # and denylist_allowlist_check(a_module.__name__, cfg)
+ ):
+ loaded_plugins.append(a_module())
+ return loaded_plugins
+
+def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
@@ -126,31 +149,16 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
- loaded_plugins = []
- current_dir = os.getcwd()
- print(current_dir)
- # Generic plugins
- plugins_path_path = Path(PLUGINS_DIR)
- for plugin in plugins_path_path.glob("*.zip"):
- if moduleList := inspect_zip_for_modules(str(plugin), debug):
- for module in moduleList:
- plugin = Path(plugin)
- module = Path(module)
- logger.debug(f"Plugin: {plugin} Module: {module}")
- zipped_package = zipimporter(str(plugin))
- zipped_module = zipped_package.load_module(str(module.parent))
- for key in dir(zipped_module):
- if key.startswith("__"):
- continue
- a_module = getattr(zipped_module, key)
- a_keys = dir(a_module)
- if (
- "_abc_impl" in a_keys
- and a_module.__name__ != "AutoGPTPluginTemplate"
- # and denylist_allowlist_check(a_module.__name__, cfg)
- ):
- loaded_plugins.append(a_module())
+ loaded_plugins = []
+ # Generic plugins
+ plugins_path = Path(plugins_file_path)
+ if file_name:
+ plugin_path = Path(plugins_path, file_name)
+ loaded_plugins = __scan_plugin_file(plugin_path)
+ else:
+ for plugin_path in plugins_path.glob("*.zip"):
+ loaded_plugins.extend(__scan_plugin_file(plugin_path))
if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
@@ -183,5 +191,55 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
return ack.lower() == cfg.authorise_key
-if __name__ == '__main__':
- print(inspect_zip_for_modules("/Users/tuyang.yhj/Downloads/DB-GPT-Plugins-main (1).zip"))
+def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
+ authorization: str = "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"):
+ os.makedirs(download_path, exist_ok=True)
+ if github_repo:
+ if github_repo.index("github.com") <= 0:
+ raise ValueError("Not a correct Github repository address!" + github_repo)
+ github_repo = github_repo.replace(".git", "")
+ url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
+ plugin_repo_name = github_repo.strip('/').split('/')[-1]
+ else:
+ url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
+ plugin_repo_name = "DB-GPT-Plugins"
+ try:
+ session = requests.Session()
+ response = session.get(
+ url,
+ headers={"Authorization": authorization},
+ )
+
+ if response.status_code == 200:
+ plugins_path_path = Path(download_path)
+ files = glob.glob(
+ os.path.join(plugins_path_path, f"{plugin_repo_name}*")
+ )
+ for file in files:
+ os.remove(file)
+ now = datetime.datetime.now()
+ time_str = now.strftime("%Y%m%d%H%M%S")
+ file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
+ print(file_name)
+ with open(file_name, "wb") as f:
+ f.write(response.content)
+ return plugin_repo_name
+ else:
+ logger.error("update plugins faild,response code:", response.status_code)
+ raise ValueError("download plugin faild!" + response.status_code)
+ except Exception as e:
+ logger.error("update plugins from git exception!" + str(e))
+ raise ValueError("download plugin exception!", e)
+
+
+def __fetch_from_git(local_path, git_url):
+ logger.info("fetch plugins from git to local path:{}", local_path)
+ os.makedirs(local_path, exist_ok=True)
+ repo = git.Repo(local_path)
+ if repo.is_repo():
+ repo.remotes.origin.pull()
+ else:
+ git.Repo.clone_from(git_url, local_path)
+
+ # if repo.head.is_valid():
+ # clone succ, fetch plugins info
diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index 162759e3c..b56176991 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -56,6 +56,13 @@ class ChatScene(Enum):
param_types=["Plugin Select"],
)
+ ChatAgent = Scene(
+ code="chat_agent",
+ name="Agent Chat",
+ describe="Use tools through dialogue to accomplish your goals.",
+ param_types=["Plugin Select"],
+ )
+
InnerChatDBSummary = Scene(
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
)
diff --git a/pilot/scene/chat_agent/__init__.py b/pilot/scene/chat_agent/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
new file mode 100644
index 000000000..9a35c595b
--- /dev/null
+++ b/pilot/scene/chat_agent/chat.py
@@ -0,0 +1,72 @@
+from typing import List, Dict
+
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.configs.config import Config
+from pilot.base_modules.agent.commands.command import execute_command
+from pilot.base_modules.agent import PluginPromptGenerator
+
+CFG = Config()
+
+
+class ChatWithPlugin(BaseChat):
+ chat_scene: str = ChatScene.ChatAgent.value()
+ plugins_prompt_generator: PluginPromptGenerator
+ select_plugin: str = None
+
+ def __init__(self, chat_param: Dict):
+ self.plugin_selector = chat_param.select_param
+ chat_param["chat_mode"] = ChatScene.ChatAgent
+ super().__init__(chat_param=chat_param)
+ self.plugins_prompt_generator = PluginPromptGenerator()
+ self.plugins_prompt_generator.command_registry = CFG.command_registry
+ # 加载插件中可用命令
+ self.select_plugin = self.plugin_selector
+ if self.select_plugin:
+ for plugin in CFG.plugins:
+ if plugin._name == self.plugin_selector:
+ if not plugin.can_handle_post_prompt():
+ continue
+ self.plugins_prompt_generator = plugin.post_prompt(
+ self.plugins_prompt_generator
+ )
+
+ else:
+ for plugin in CFG.plugins:
+ if not plugin.can_handle_post_prompt():
+ continue
+ self.plugins_prompt_generator = plugin.post_prompt(
+ self.plugins_prompt_generator
+ )
+
+ def generate_input_values(self):
+ input_values = {
+ "input": self.current_user_input,
+ "constraints": self.__list_to_prompt_str(
+ list(self.plugins_prompt_generator.constraints)
+ ),
+ "commands_infos": self.plugins_prompt_generator.generate_commands_string(),
+ }
+ return input_values
+
+ def do_action(self, prompt_response):
+ print(f"do_action:{prompt_response}")
+ ## plugin command run
+ return execute_command(
+ str(prompt_response.command.get("name")),
+ prompt_response.command.get("args", {}),
+ self.plugins_prompt_generator,
+ )
+
+ def chat_show(self):
+ super().chat_show()
+
+ def __list_to_prompt_str(self, list: List) -> str:
+ return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
+
+ def generate(self, p) -> str:
+ return super().generate(p)
+
+ @property
+ def chat_type(self) -> str:
+ return ChatScene.ChatAgent.value
diff --git a/pilot/scene/chat_agent/example.py b/pilot/scene/chat_agent/example.py
new file mode 100644
index 000000000..d4b84d757
--- /dev/null
+++ b/pilot/scene/chat_agent/example.py
@@ -0,0 +1,23 @@
+from pilot.prompts.example_base import ExampleSelector
+
+## Two examples are defined by default
+EXAMPLES = [
+ {
+ "messages": [
+ {"type": "human", "data": {"content": "查询xxx", "example": True}},
+ {
+ "type": "ai",
+ "data": {
+ "content": """{
+ \"thoughts\": \"thought text\",
+ \"speak\": \"thoughts summary to say to user\",
+ \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
+ }""",
+ "example": True,
+ },
+ },
+ ]
+ },
+]
+
+plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
diff --git a/pilot/scene/chat_agent/out_parser.py b/pilot/scene/chat_agent/out_parser.py
new file mode 100644
index 000000000..2078018d0
--- /dev/null
+++ b/pilot/scene/chat_agent/out_parser.py
@@ -0,0 +1,46 @@
+import json
+from typing import Dict, NamedTuple
+from pilot.utils import build_logger
+from pilot.out_parser.base import BaseOutputParser, T
+from pilot.configs.model_config import LOGDIR
+
+
+logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+
+
+class PluginAction(NamedTuple):
+ command: Dict
+ speak: str = ""
+ thoughts: str = ""
+
+
+class PluginChatOutputParser(BaseOutputParser):
+ def parse_prompt_response(self, model_out_text) -> T:
+ clean_json_str = super().parse_prompt_response(model_out_text)
+ print(clean_json_str)
+ if not clean_json_str:
+ raise ValueError("model server response not have json!")
+ try:
+ response = json.loads(clean_json_str)
+ except Exception as e:
+ raise ValueError("model server out not fllow the prompt!")
+
+ speak = ""
+ thoughts = ""
+ for key in sorted(response):
+ if key.strip() == "command":
+ command = response[key]
+ if key.strip() == "thoughts":
+ thoughts = response[key]
+ if key.strip() == "speak":
+ speak = response[key]
+ return PluginAction(command, speak, thoughts)
+
+ def parse_view_response(self, speak, data) -> str:
+ ### tool out data to table view
+ print(f"parse_view_response:{speak},{str(data)}")
+ view_text = f"##### {speak}" + "\n" + str(data)
+ return view_text
+
+ def get_format_instructions(self) -> str:
+ pass
diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py
new file mode 100644
index 000000000..b6be5a2c3
--- /dev/null
+++ b/pilot/scene/chat_agent/prompt.py
@@ -0,0 +1,78 @@
+import json
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.common.schema import SeparatorStyle, ExampleType
+
+from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
+from pilot.scene.chat_execution.example import plugin_example
+
+CFG = Config()
+
+_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant."
+
+_DEFAULT_TEMPLATE_EN = """
+You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals.
+Tool list:
+ {tool_list}
+Constraint:
+ 1. After selecting an available tool, please ensure that the output results include the following parts to use the tool:
+ Selected Tool name Parameter valueParameter value 2
+ 2. If you cannot analyze the exact tool for the problem, you can consider using the search engine tool among the tools first.
+ 3. Parameter content may need to be inferred based on the user's goals, not just extracted from text
+ 4. If you cannot find a suitable tool, please answer Unable to complete the goal.
+ {expand_constraints}
+User goals:
+ {user_goal}
+"""
+
+_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手!"
+
+_DEFAULT_TEMPLATE_ZH = """
+请一步步思考,如何在满足下面约束条件的前提下,回答或解决用户问题或目标。
+工具列表:
+ {tool_list}
+约束条件:
+ 1. 找到可用的工具后,请确保输出结果包含以下内容用来使用工具:Selected Tool name Parameter valueParameter value 2
+ 2.任务重可以使用多个工具,上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
+ 3.如果有多个工具被使用,后续工具需要第一个工具的结果作为参数的, 使用如下文本来替代参数值:
+ 4.如果对于问题无法理解和解决,可以考虑优先使用工具中的搜索引擎工具
+ 5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
+ 6.如果中无法找到合适的工具,请回答无法完成目标。
+ 7.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中
+ {expand_constraints}
+用户目标:
+ {user_goal}
+"""
+
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
+
+
+_PROMPT_SCENE_DEFINE=(
+ _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
+)
+
+RESPONSE_FORMAT = None
+
+
+EXAMPLE_TYPE = ExampleType.ONE_SHOT
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+### Whether the model service is streaming output
+PROMPT_NEED_STREAM_OUT = True
+
+prompt = PromptTemplate(
+ template_scene=ChatScene.ChatAgent.value(),
+ input_variables=["tool_list", "expand_constraints", "user_goal"],
+ response_format=None,
+ template_define=_PROMPT_SCENE_DEFINE,
+ template=_DEFAULT_TEMPLATE,
+ stream_out=PROMPT_NEED_STREAM_OUT,
+ output_parser=PluginChatOutputParser(
+ sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
+ ),
+ # example_selector=plugin_example,
+)
+
+CFG.prompt_template_registry.register(prompt, is_default=True)
diff --git a/pilot/scene/chat_execution/prompt_v2.py b/pilot/scene/chat_execution/prompt_v2.py
deleted file mode 100644
index 8b1378917..000000000
--- a/pilot/scene/chat_execution/prompt_v2.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/pilot/server/base.py b/pilot/server/base.py
index ec08fe074..d363e05bd 100644
--- a/pilot/server/base.py
+++ b/pilot/server/base.py
@@ -6,6 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field
from pilot.configs.config import Config
+from pilot.configs.model_config import PLUGINS_DIR
from pilot.componet import SystemApp
from pilot.utils.parameter_utils import BaseParameters
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
@@ -30,7 +31,7 @@ def async_db_summery(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
- from pilot.base_modules.agent.plugins import scan_plugins
+ from pilot.base_modules.agent.plugins_util import scan_plugins
# logger.info(f"args: {args}")
@@ -43,7 +44,7 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
- cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
+ cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
# Loader plugins and commands
command_categories = [
@@ -126,3 +127,4 @@ class WebWerverParameters(BaseParameters):
},
)
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
+