diff --git a/pilot/base_modules/agent/agent.py b/pilot/base_modules/agent/agent.py deleted file mode 100644 index d11f21ab5..000000000 --- a/pilot/base_modules/agent/agent.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod - - - - - - -class AgentFacade(ABC): - def __init__(self) -> None: - self.model = None - - - diff --git a/pilot/base_modules/agent/commands/command.py b/pilot/base_modules/agent/commands/command.py index b4cb0c0f6..0949ddb82 100644 --- a/pilot/base_modules/agent/commands/command.py +++ b/pilot/base_modules/agent/commands/command.py @@ -27,7 +27,6 @@ def execute_ai_response_json( user_input: str = None, ) -> str: """ - Args: command_registry: ai_response: @@ -65,6 +64,8 @@ def execute_ai_response_json( return result + + def execute_command( command_name: str, arguments, diff --git a/pilot/base_modules/agent/controller.py b/pilot/base_modules/agent/controller.py index cbf0e3b7b..f6ecd2174 100644 --- a/pilot/base_modules/agent/controller.py +++ b/pilot/base_modules/agent/controller.py @@ -3,6 +3,8 @@ import time from fastapi import ( APIRouter, Body, + UploadFile, + File, ) from typing import List @@ -13,14 +15,77 @@ from pilot.openapi.api_view_model import ( Result, ) - +from .model import PluginHubParam, PagenationFilter, PagenationResult +from .hub.agent_hub import AgentHub +from .db.plugin_hub_db import PluginHubEntity +from .db.my_plugin_db import MyPluginEntity +from pilot.configs.model_config import PLUGINS_DIR router = APIRouter() logger = build_logger("agent_mange", LOGDIR + "agent_mange.log") -@router.get("/v1/mange/agent/list", response_model=Result[str]) -async def get_agent_list(): - logger.info(f"get_agent_list!") - +@router.post("/v1/agent/hub/update", response_model=Result[str]) +async def agent_hub_update(update_param: PluginHubParam = Body()): + logger.info(f"agent_hub_update:{update_param.__dict__}") + agent_hub = AgentHub(PLUGINS_DIR) + agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization) return Result.succ(None) + + +@router.post("/v1/agent/query", response_model=Result[str]) +async def get_agent_list(filter: PagenationFilter[PluginHubEntity] = Body()): + logger.info(f"get_agent_list:{json.dumps(filter)}") + agent_hub = AgentHub(PLUGINS_DIR) + datas, total_pages, total_count = agent_hub.hub_dao.list(filter.filter, filter.page_index, filter.page_size) + result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]() + result.page_index = filter.page_index + result.page_size = filter.page_size + result.total_page = total_pages + result.total_row_count = total_count + result.datas = datas + return Result.succ(result) + +@router.post("/v1/agent/my", response_model=Result[str]) +async def my_agents(user:str= None): + logger.info(f"my_agents:{json.dumps(my_agents)}") + agent_hub = AgentHub(PLUGINS_DIR) + return Result.succ(agent_hub.get_my_plugin(user)) + + +@router.post("/v1/agent/install", response_model=Result[str]) +async def agent_install(plugin_name:str, user: str = None): + logger.info(f"agent_install:{plugin_name},{user}") + try: + agent_hub = AgentHub(PLUGINS_DIR) + agent_hub.install_plugin(plugin_name, user) + return Result.succ(None) + except Exception as e: + logger.error("Plugin Install Error!", e) + return Result.faild(code="E0021", msg=f"Plugin Install Error {e}") + + + +@router.post("/v1/agent/uninstall", response_model=Result[str]) +async def agent_uninstall(plugin_name:str, user: str = None): + logger.info(f"agent_uninstall:{plugin_name},{user}") + try: + agent_hub = AgentHub(PLUGINS_DIR) + agent_hub.uninstall_plugin(plugin_name, user) + return Result.succ(None) + except Exception as e: + logger.error("Plugin Uninstall Error!", e) + return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}") + + +@router.post("/v1/personal/agent/upload", response_model=Result[str]) +async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None): + logger.info(f"personal_agent_upload:{doc_file.filename},{user}") + try: + agent_hub = AgentHub(PLUGINS_DIR) + agent_hub.upload_my_plugin(doc_file, user) + return Result.succ(None) + except Exception as e: + logger.error("Upload Personal Plugin Error!", e) + return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}") + diff --git a/pilot/base_modules/agent/db/my_plugin_db.py b/pilot/base_modules/agent/db/my_plugin_db.py index 2e2a73ed9..545866854 100644 --- a/pilot/base_modules/agent/db/my_plugin_db.py +++ b/pilot/base_modules/agent/db/my_plugin_db.py @@ -18,6 +18,7 @@ class MyPluginEntity(Base): user_code = Column(String, nullable=True, comment="user code") user_name = Column(String, nullable=True, comment="user name") name = Column(String, unique=True, nullable=False, comment="plugin name") + file_name = Column(String, nullable=False, comment="plugin package file name") type = Column(String, comment="plugin type") version = Column(String, comment="plugin version") use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") @@ -74,6 +75,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]): def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]: session = self.Session() my_plugins = session.query(MyPluginEntity) + all_count = my_plugins.count() if query.id is not None: my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) if query.name is not None: @@ -101,7 +103,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]): my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size) result = my_plugins.all() session.close() - return result + total_pages = all_count // page_size + if all_count % page_size != 0: + total_pages += 1 + + + return result, total_pages, all_count + def count(self, query: MyPluginEntity): session = self.Session() diff --git a/pilot/base_modules/agent/db/plugin_hub_db.py b/pilot/base_modules/agent/db/plugin_hub_db.py index 1dfb2363d..b579f5071 100644 --- a/pilot/base_modules/agent/db/plugin_hub_db.py +++ b/pilot/base_modules/agent/db/plugin_hub_db.py @@ -13,12 +13,14 @@ class PluginHubEntity(Base): __tablename__ = 'plugin_hub' id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") name = Column(String, unique=True, nullable=False, comment="plugin name") + description = Column(String, nullable=False, comment="plugin description") author = Column(String, nullable=True, comment="plugin author") email = Column(String, nullable=True, comment="plugin author email") type = Column(String, comment="plugin type") version = Column(String, comment="plugin version") storage_channel = Column(String, comment="plugin storage channel") storage_url = Column(String, comment="plugin download url") + download_param = Column(String, comment="plugin download param") created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time") installed = Column(Boolean, default=False, comment="plugin already installed") @@ -61,6 +63,8 @@ class PluginHubDao(BaseDao[PluginHubEntity]): def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]: session = self.Session() plugin_hubs = session.query(PluginHubEntity) + all_count = plugin_hubs.count() + if query.id is not None: plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) if query.name is not None: @@ -84,6 +88,20 @@ class PluginHubDao(BaseDao[PluginHubEntity]): plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size) result = plugin_hubs.all() session.close() + + total_pages = all_count // page_size + if all_count % page_size != 0: + total_pages += 1 + + + return result, total_pages, all_count + + def get_by_storage_url(self, storage_url): + session = self.Session() + plugin_hubs = session.query(PluginHubEntity) + plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url) + result = plugin_hubs.all() + session.close() return result def get_by_name(self, name: str) -> PluginHubEntity: diff --git a/pilot/base_modules/agent/hub/agent_hub.py b/pilot/base_modules/agent/hub/agent_hub.py index 55bc1c7a6..2d2daaf50 100644 --- a/pilot/base_modules/agent/hub/agent_hub.py +++ b/pilot/base_modules/agent/hub/agent_hub.py @@ -1,39 +1,54 @@ +import json import logging -import git import os +import glob +import zipfile +import requests +from pathlib import Path +import datetime +import shutil +from fastapi import UploadFile +from typing import Any +import tempfile from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao from ..db.my_plugin_db import MyPluginDao, MyPluginEntity from .schema import PluginStorageType +from ..plugins_util import scan_plugins, update_from_git logger = logging.getLogger("agent_hub") Default_User = "default" DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git" TEMP_PLUGIN_PATH = "" + class AgentHub: - def __init__(self, temp_hub_file_path:str = "") -> None: + def __init__(self, plugin_dir) -> None: self.hub_dao = PluginHubDao() self.my_lugin_dao = MyPluginDao() - if temp_hub_file_path: - self.temp_hub_file_path = temp_hub_file_path - else: - self.temp_hub_file_path = os.path.join(os.getcwd(), "plugins", "temp") + os.makedirs(plugin_dir, exist_ok=True) + self.plugin_dir = plugin_dir + self.temp_hub_file_path = os.path.join(plugin_dir, "temp") def install_plugin(self, plugin_name: str, user_name: str = None): logger.info(f"install_plugin {plugin_name}") - plugin_entity = self.hub_dao.get_by_name(plugin_name) if plugin_entity: if plugin_entity.storage_channel == PluginStorageType.Git.value: try: - self.__download_from_git(plugin_name, plugin_entity.storage_url) - self.load_plugin(plugin_name) + branch_name = None + authorization = None + if plugin_entity.download_param: + download_param = json.loads(plugin_entity.download_param) + branch_name = download_param.get("branch_name") + authorization = download_param.get("authorization") + file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization) # add to my plugins and edit hub status plugin_entity.installed = True my_plugin_entity = self.__build_my_plugin(plugin_entity) + my_plugin_entity.file_name = file_name if user_name: # TODO use user my_plugin_entity.user_code = "" @@ -55,9 +70,43 @@ class AgentHub: else: raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!") + else: raise ValueError(f"Can't Find Plugin {plugin_name}!") + def uninstall_plugin(self, plugin_name, user): + logger.info(f"uninstall_plugin:{plugin_name},{user}") + plugin_entity = self.hub_dao.get_by_name(plugin_name) + plugin_entity.installed = False + with self.hub_dao.Session() as session: + try: + my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name) + if user: + my_plugin_q.filter(MyPluginEntity.user_code == user) + my_plugin_q.delete() + session.merge(plugin_entity) + session.commit() + except: + session.rollback() + + # delete package file if not use + plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url) + have_installed = False + for plugin_info in plugin_infos: + if plugin_info.installed: + have_installed = True + break + if not have_installed: + plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1] + files = glob.glob( + os.path.join(self.plugin_dir, f"{plugin_repo_name}*") + ) + for file in files: + os.remove(file) + + def __download_from_git(self, github_repo, branch_name, authorization): + return update_from_git(self.plugin_dir, github_repo, branch_name, authorization) + def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity: my_plugin_entity = MyPluginEntity() my_plugin_entity.name = hub_plugin.name @@ -65,29 +114,68 @@ class AgentHub: my_plugin_entity.version = hub_plugin.version return my_plugin_entity - def __fetch_from_git(self): - logger.info("fetch plugins from git to local path:{}", self.temp_hub_file_path) - os.makedirs(self.temp_hub_file_path, exist_ok=True) - repo = git.Repo(self.temp_hub_file_path) - if repo.is_repo(): - repo.remotes.origin.pull() - else: - git.Repo.clone_from(DEFAULT_PLUGIN_REPO, self.temp_hub_file_path) + def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None): + logger.info("refresh_hub_by_git start!") + update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization) + git_plugins = scan_plugins(self.temp_hub_file_path) + for git_plugin in git_plugins: + old_hub_info = self.hub_dao.get_by_name(git_plugin._name) + if old_hub_info: + plugin_hub_info = old_hub_info + else: + plugin_hub_info = PluginHubEntity() + plugin_hub_info.type = "" + plugin_hub_info.storage_channel = PluginStorageType.Git.value + plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO + plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT') + plugin_hub_info.email = getattr(git_plugin, '_email', '') + plugin_hub_info.download_param = json.dumps({ + branch_name: branch_name, + authorization: authorization + }) + plugin_hub_info.installed = False - # if repo.head.is_valid(): - # clone succ, fetch plugins info + plugin_hub_info.name = git_plugin._name + plugin_hub_info.version = git_plugin._version + plugin_hub_info.description = git_plugin._description + self.hub_dao.update(plugin_hub_info) + + def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User): + + # We can not move temp file in windows system when we open file in context of `with` + file_path = os.path.join(self.plugin_dir, doc_file.filename) + if os.path.exists(file_path): + os.remove(file_path) + tmp_fd, tmp_path = tempfile.mkstemp( + dir=os.path.join(self.plugin_dir) + ) + with os.fdopen(tmp_fd, "wb") as tmp: + tmp.write(await doc_file.read()) + shutil.move( + tmp_path, + os.path.join(self.plugin_dir, doc_file.filename), + ) + + my_plugins = scan_plugins(self.plugin_dir, doc_file.filename) + for my_plugin in my_plugins: + my_plugin_entiy = MyPluginEntity() + + my_plugin_entiy.name = my_plugin._name + my_plugin_entiy.version = my_plugin._version + my_plugin_entiy.type = "Personal" + my_plugin_entiy.user_code = user + my_plugin_entiy.user_name = user + my_plugin_entiy.tenant = "" + my_plugin_entiy.file_name = doc_file.filename + + self.my_lugin_dao.update(my_plugin_entiy) - def upload_plugin_in_hub(self, name: str, path: str): - pass - def __download_from_git(self, plugin_name, url): - pass - - def load_plugin(self, plugin_name): - logger.info(f"load_plugin:{plugin_name}") - pass + def reload_my_plugins(self): + logger.info(f"load_plugins start!") + return scan_plugins(self.plugin_dir) def get_my_plugin(self, user: str): logger.info(f"get_my_plugin:{user}") @@ -95,7 +183,3 @@ class AgentHub: user = Default_User return self.my_lugin_dao.get_by_user(user) - def uninstall_plugin(self, plugin_name, user): - logger.info(f"uninstall_plugin:{plugin_name},{user}") - - pass diff --git a/pilot/base_modules/agent/model.py b/pilot/base_modules/agent/model.py new file mode 100644 index 000000000..2d268e138 --- /dev/null +++ b/pilot/base_modules/agent/model.py @@ -0,0 +1,30 @@ +from typing import TypedDict, Optional, Dict, List +from dataclasses import dataclass +from pydantic import BaseModel, Field +from typing import TypeVar, Generic, Any + +T = TypeVar('T') + +class PagenationFilter(Generic[T]): + page_index: int = 1 + page_size: int = 20 + filter: T = None + +class PagenationResult(Generic[T]): + page_index: int = 1 + page_size: int = 20 + total_page: int = 0 + total_row_count: int = 0 + datas: List[T] = [] + + + +@dataclass +class PluginHubParam: + channel: str = Field(..., description="Plugin storage channel") + url: str = Field(..., description="Plugin storage url") + + branch: str = Field(..., description="When the storage channel is github, use to specify the branch", nullable=True) + authorization: str = Field(..., description="github download authorization", nullable=True) + + diff --git a/pilot/base_modules/agent/module.py b/pilot/base_modules/agent/module.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/base_modules/agent/plugins_loader.py b/pilot/base_modules/agent/plugins_loader.py new file mode 100644 index 000000000..2d5c55a8e --- /dev/null +++ b/pilot/base_modules/agent/plugins_loader.py @@ -0,0 +1,2 @@ + +class PluginLoader(): diff --git a/pilot/base_modules/agent/plugins.py b/pilot/base_modules/agent/plugins_util.py similarity index 60% rename from pilot/base_modules/agent/plugins.py rename to pilot/base_modules/agent/plugins_util.py index f1734e43b..5a4f86229 100644 --- a/pilot/base_modules/agent/plugins.py +++ b/pilot/base_modules/agent/plugins_util.py @@ -5,6 +5,7 @@ import os import glob import zipfile import requests +import git import threading import datetime from pathlib import Path @@ -20,7 +21,6 @@ from pilot.configs.model_config import PLUGINS_DIR from pilot.logs import logger - def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: """ Loader zip plugin file. Native support Auto_gpt_plugin @@ -106,7 +106,7 @@ def load_native_plugins(cfg: Config): with open(file_name, "wb") as f: f.write(response.content) print("save file") - cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) + cfg.set_plugins(scan_plugins(cfg.debug_mode)) else: print("get file faild,response code:", response.status_code) except Exception as e: @@ -116,7 +116,30 @@ def load_native_plugins(cfg: Config): t.start() -def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: +def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]: + logger.info(f"__scan_plugin_file:{file_path},{debug}") + loaded_plugins = [] + if moduleList := inspect_zip_for_modules(str(file_path), debug): + for module in moduleList: + plugin = Path(plugin) + module = Path(module) + logger.debug(f"Plugin: {plugin} Module: {module}") + zipped_package = zipimporter(str(plugin)) + zipped_module = zipped_package.load_module(str(module.parent)) + for key in dir(zipped_module): + if key.startswith("__"): + continue + a_module = getattr(zipped_module, key) + a_keys = dir(a_module) + if ( + "_abc_impl" in a_keys + and a_module.__name__ != "AutoGPTPluginTemplate" + # and denylist_allowlist_check(a_module.__name__, cfg) + ): + loaded_plugins.append(a_module()) + return loaded_plugins + +def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]: """Scan the plugins directory for plugins and loads them. Args: @@ -126,31 +149,16 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate Returns: List[Tuple[str, Path]]: List of plugins. """ - loaded_plugins = [] - current_dir = os.getcwd() - print(current_dir) - # Generic plugins - plugins_path_path = Path(PLUGINS_DIR) - for plugin in plugins_path_path.glob("*.zip"): - if moduleList := inspect_zip_for_modules(str(plugin), debug): - for module in moduleList: - plugin = Path(plugin) - module = Path(module) - logger.debug(f"Plugin: {plugin} Module: {module}") - zipped_package = zipimporter(str(plugin)) - zipped_module = zipped_package.load_module(str(module.parent)) - for key in dir(zipped_module): - if key.startswith("__"): - continue - a_module = getattr(zipped_module, key) - a_keys = dir(a_module) - if ( - "_abc_impl" in a_keys - and a_module.__name__ != "AutoGPTPluginTemplate" - # and denylist_allowlist_check(a_module.__name__, cfg) - ): - loaded_plugins.append(a_module()) + loaded_plugins = [] + # Generic plugins + plugins_path = Path(plugins_file_path) + if file_name: + plugin_path = Path(plugins_path, file_name) + loaded_plugins = __scan_plugin_file(plugin_path) + else: + for plugin_path in plugins_path.glob("*.zip"): + loaded_plugins.extend(__scan_plugin_file(plugin_path)) if loaded_plugins: logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") @@ -183,5 +191,55 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool: return ack.lower() == cfg.authorise_key -if __name__ == '__main__': - print(inspect_zip_for_modules("/Users/tuyang.yhj/Downloads/DB-GPT-Plugins-main (1).zip")) +def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main", + authorization: str = "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"): + os.makedirs(download_path, exist_ok=True) + if github_repo: + if github_repo.index("github.com") <= 0: + raise ValueError("Not a correct Github repository address!" + github_repo) + github_repo = github_repo.replace(".git", "") + url = github_repo + "/archive/refs/heads/" + branch_name + ".zip" + plugin_repo_name = github_repo.strip('/').split('/')[-1] + else: + url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip" + plugin_repo_name = "DB-GPT-Plugins" + try: + session = requests.Session() + response = session.get( + url, + headers={"Authorization": authorization}, + ) + + if response.status_code == 200: + plugins_path_path = Path(download_path) + files = glob.glob( + os.path.join(plugins_path_path, f"{plugin_repo_name}*") + ) + for file in files: + os.remove(file) + now = datetime.datetime.now() + time_str = now.strftime("%Y%m%d%H%M%S") + file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip" + print(file_name) + with open(file_name, "wb") as f: + f.write(response.content) + return plugin_repo_name + else: + logger.error("update plugins faild,response code:", response.status_code) + raise ValueError("download plugin faild!" + response.status_code) + except Exception as e: + logger.error("update plugins from git exception!" + str(e)) + raise ValueError("download plugin exception!", e) + + +def __fetch_from_git(local_path, git_url): + logger.info("fetch plugins from git to local path:{}", local_path) + os.makedirs(local_path, exist_ok=True) + repo = git.Repo(local_path) + if repo.is_repo(): + repo.remotes.origin.pull() + else: + git.Repo.clone_from(git_url, local_path) + + # if repo.head.is_valid(): + # clone succ, fetch plugins info diff --git a/pilot/scene/base.py b/pilot/scene/base.py index 162759e3c..b56176991 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -56,6 +56,13 @@ class ChatScene(Enum): param_types=["Plugin Select"], ) + ChatAgent = Scene( + code="chat_agent", + name="Agent Chat", + describe="Use tools through dialogue to accomplish your goals.", + param_types=["Plugin Select"], + ) + InnerChatDBSummary = Scene( "inner_chat_db_summary", "DB Summary", "Db Summary.", True ) diff --git a/pilot/scene/chat_agent/__init__.py b/pilot/scene/chat_agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py new file mode 100644 index 000000000..9a35c595b --- /dev/null +++ b/pilot/scene/chat_agent/chat.py @@ -0,0 +1,72 @@ +from typing import List, Dict + +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.configs.config import Config +from pilot.base_modules.agent.commands.command import execute_command +from pilot.base_modules.agent import PluginPromptGenerator + +CFG = Config() + + +class ChatWithPlugin(BaseChat): + chat_scene: str = ChatScene.ChatAgent.value() + plugins_prompt_generator: PluginPromptGenerator + select_plugin: str = None + + def __init__(self, chat_param: Dict): + self.plugin_selector = chat_param.select_param + chat_param["chat_mode"] = ChatScene.ChatAgent + super().__init__(chat_param=chat_param) + self.plugins_prompt_generator = PluginPromptGenerator() + self.plugins_prompt_generator.command_registry = CFG.command_registry + # 加载插件中可用命令 + self.select_plugin = self.plugin_selector + if self.select_plugin: + for plugin in CFG.plugins: + if plugin._name == self.plugin_selector: + if not plugin.can_handle_post_prompt(): + continue + self.plugins_prompt_generator = plugin.post_prompt( + self.plugins_prompt_generator + ) + + else: + for plugin in CFG.plugins: + if not plugin.can_handle_post_prompt(): + continue + self.plugins_prompt_generator = plugin.post_prompt( + self.plugins_prompt_generator + ) + + def generate_input_values(self): + input_values = { + "input": self.current_user_input, + "constraints": self.__list_to_prompt_str( + list(self.plugins_prompt_generator.constraints) + ), + "commands_infos": self.plugins_prompt_generator.generate_commands_string(), + } + return input_values + + def do_action(self, prompt_response): + print(f"do_action:{prompt_response}") + ## plugin command run + return execute_command( + str(prompt_response.command.get("name")), + prompt_response.command.get("args", {}), + self.plugins_prompt_generator, + ) + + def chat_show(self): + super().chat_show() + + def __list_to_prompt_str(self, list: List) -> str: + return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) + + def generate(self, p) -> str: + return super().generate(p) + + @property + def chat_type(self) -> str: + return ChatScene.ChatAgent.value diff --git a/pilot/scene/chat_agent/example.py b/pilot/scene/chat_agent/example.py new file mode 100644 index 000000000..d4b84d757 --- /dev/null +++ b/pilot/scene/chat_agent/example.py @@ -0,0 +1,23 @@ +from pilot.prompts.example_base import ExampleSelector + +## Two examples are defined by default +EXAMPLES = [ + { + "messages": [ + {"type": "human", "data": {"content": "查询xxx", "example": True}}, + { + "type": "ai", + "data": { + "content": """{ + \"thoughts\": \"thought text\", + \"speak\": \"thoughts summary to say to user\", + \"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}}, + }""", + "example": True, + }, + }, + ] + }, +] + +plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_agent/out_parser.py b/pilot/scene/chat_agent/out_parser.py new file mode 100644 index 000000000..2078018d0 --- /dev/null +++ b/pilot/scene/chat_agent/out_parser.py @@ -0,0 +1,46 @@ +import json +from typing import Dict, NamedTuple +from pilot.utils import build_logger +from pilot.out_parser.base import BaseOutputParser, T +from pilot.configs.model_config import LOGDIR + + +logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") + + +class PluginAction(NamedTuple): + command: Dict + speak: str = "" + thoughts: str = "" + + +class PluginChatOutputParser(BaseOutputParser): + def parse_prompt_response(self, model_out_text) -> T: + clean_json_str = super().parse_prompt_response(model_out_text) + print(clean_json_str) + if not clean_json_str: + raise ValueError("model server response not have json!") + try: + response = json.loads(clean_json_str) + except Exception as e: + raise ValueError("model server out not fllow the prompt!") + + speak = "" + thoughts = "" + for key in sorted(response): + if key.strip() == "command": + command = response[key] + if key.strip() == "thoughts": + thoughts = response[key] + if key.strip() == "speak": + speak = response[key] + return PluginAction(command, speak, thoughts) + + def parse_view_response(self, speak, data) -> str: + ### tool out data to table view + print(f"parse_view_response:{speak},{str(data)}") + view_text = f"##### {speak}" + "\n" + str(data) + return view_text + + def get_format_instructions(self) -> str: + pass diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py new file mode 100644 index 000000000..b6be5a2c3 --- /dev/null +++ b/pilot/scene/chat_agent/prompt.py @@ -0,0 +1,78 @@ +import json +from pilot.prompts.prompt_new import PromptTemplate +from pilot.configs.config import Config +from pilot.scene.base import ChatScene +from pilot.common.schema import SeparatorStyle, ExampleType + +from pilot.scene.chat_execution.out_parser import PluginChatOutputParser +from pilot.scene.chat_execution.example import plugin_example + +CFG = Config() + +_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant." + +_DEFAULT_TEMPLATE_EN = """ +You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals. +Tool list: + {tool_list} +Constraint: + 1. After selecting an available tool, please ensure that the output results include the following parts to use the tool: + Selected Tool name Parameter valueParameter value 2 + 2. If you cannot analyze the exact tool for the problem, you can consider using the search engine tool among the tools first. + 3. Parameter content may need to be inferred based on the user's goals, not just extracted from text + 4. If you cannot find a suitable tool, please answer Unable to complete the goal. + {expand_constraints} +User goals: + {user_goal} +""" + +_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手!" + +_DEFAULT_TEMPLATE_ZH = """ +请一步步思考,如何在满足下面约束条件的前提下,回答或解决用户问题或目标。 +工具列表: + {tool_list} +约束条件: + 1. 找到可用的工具后,请确保输出结果包含以下内容用来使用工具:Selected Tool name Parameter valueParameter value 2 + 2.任务重可以使用多个工具,上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成 + 3.如果有多个工具被使用,后续工具需要第一个工具的结果作为参数的, 使用如下文本来替代参数值: + 4.如果对于问题无法理解和解决,可以考虑优先使用工具中的搜索引擎工具 + 5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取 + 6.如果中无法找到合适的工具,请回答无法完成目标。 + 7.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中 + {expand_constraints} +用户目标: + {user_goal} +""" + +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) + + +_PROMPT_SCENE_DEFINE=( + _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH +) + +RESPONSE_FORMAT = None + + +EXAMPLE_TYPE = ExampleType.ONE_SHOT +PROMPT_SEP = SeparatorStyle.SINGLE.value +### Whether the model service is streaming output +PROMPT_NEED_STREAM_OUT = True + +prompt = PromptTemplate( + template_scene=ChatScene.ChatAgent.value(), + input_variables=["tool_list", "expand_constraints", "user_goal"], + response_format=None, + template_define=_PROMPT_SCENE_DEFINE, + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_STREAM_OUT, + output_parser=PluginChatOutputParser( + sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT + ), + # example_selector=plugin_example, +) + +CFG.prompt_template_registry.register(prompt, is_default=True) diff --git a/pilot/scene/chat_execution/prompt_v2.py b/pilot/scene/chat_execution/prompt_v2.py deleted file mode 100644 index 8b1378917..000000000 --- a/pilot/scene/chat_execution/prompt_v2.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/pilot/server/base.py b/pilot/server/base.py index ec08fe074..d363e05bd 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -6,6 +6,7 @@ from typing import Optional, Any from dataclasses import dataclass, field from pilot.configs.config import Config +from pilot.configs.model_config import PLUGINS_DIR from pilot.componet import SystemApp from pilot.utils.parameter_utils import BaseParameters from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade @@ -30,7 +31,7 @@ def async_db_summery(system_app: SystemApp): def server_init(args, system_app: SystemApp): from pilot.base_modules.agent.commands.command_mange import CommandRegistry - from pilot.base_modules.agent.plugins import scan_plugins + from pilot.base_modules.agent.plugins_util import scan_plugins # logger.info(f"args: {args}") @@ -43,7 +44,7 @@ def server_init(args, system_app: SystemApp): # load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) - cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) + cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode)) # Loader plugins and commands command_categories = [ @@ -126,3 +127,4 @@ class WebWerverParameters(BaseParameters): }, ) light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"}) +