fix(ChatExcel): ChatExcel OutParse Bug Fix

1.ChatDashboard Display optimization
This commit is contained in:
yhjun1026
2023-09-19 14:56:34 +08:00
parent b3e31a84b1
commit 3cafee451e
18 changed files with 563 additions and 83 deletions

View File

@@ -1,13 +0,0 @@
from abc import ABC, abstractmethod
class AgentFacade(ABC):
def __init__(self) -> None:
self.model = None

View File

@@ -27,7 +27,6 @@ def execute_ai_response_json(
user_input: str = None, user_input: str = None,
) -> str: ) -> str:
""" """
Args: Args:
command_registry: command_registry:
ai_response: ai_response:
@@ -65,6 +64,8 @@ def execute_ai_response_json(
return result return result
def execute_command( def execute_command(
command_name: str, command_name: str,
arguments, arguments,

View File

@@ -3,6 +3,8 @@ import time
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Body, Body,
UploadFile,
File,
) )
from typing import List from typing import List
@@ -13,14 +15,77 @@ from pilot.openapi.api_view_model import (
Result, Result,
) )
from .model import PluginHubParam, PagenationFilter, PagenationResult
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .db.my_plugin_db import MyPluginEntity
from pilot.configs.model_config import PLUGINS_DIR
router = APIRouter() router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log") logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
@router.get("/v1/mange/agent/list", response_model=Result[str]) @router.post("/v1/agent/hub/update", response_model=Result[str])
async def get_agent_list(): async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"get_agent_list!") logger.info(f"agent_hub_update:{update_param.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
return Result.succ(None) return Result.succ(None)
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubEntity] = Body()):
logger.info(f"get_agent_list:{json.dumps(filter)}")
agent_hub = AgentHub(PLUGINS_DIR)
datas, total_pages, total_count = agent_hub.hub_dao.list(filter.filter, filter.page_index, filter.page_size)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
return Result.succ(result)
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user:str= None):
logger.info(f"my_agents:{json.dumps(my_agents)}")
agent_hub = AgentHub(PLUGINS_DIR)
return Result.succ(agent_hub.get_my_plugin(user))
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name:str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name:str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.upload_my_plugin(doc_file, user)
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")

View File

@@ -18,6 +18,7 @@ class MyPluginEntity(Base):
user_code = Column(String, nullable=True, comment="user code") user_code = Column(String, nullable=True, comment="user code")
user_name = Column(String, nullable=True, comment="user name") user_name = Column(String, nullable=True, comment="user name")
name = Column(String, unique=True, nullable=False, comment="plugin name") name = Column(String, unique=True, nullable=False, comment="plugin name")
file_name = Column(String, nullable=False, comment="plugin package file name")
type = Column(String, comment="plugin type") type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version") version = Column(String, comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
@@ -74,6 +75,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]: def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
session = self.Session() session = self.Session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None: if query.name is not None:
@@ -101,7 +103,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size) my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
result = my_plugins.all() result = my_plugins.all()
session.close() session.close()
return result total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def count(self, query: MyPluginEntity): def count(self, query: MyPluginEntity):
session = self.Session() session = self.Session()

View File

@@ -13,12 +13,14 @@ class PluginHubEntity(Base):
__tablename__ = 'plugin_hub' __tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id") id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
name = Column(String, unique=True, nullable=False, comment="plugin name") name = Column(String, unique=True, nullable=False, comment="plugin name")
description = Column(String, nullable=False, comment="plugin description")
author = Column(String, nullable=True, comment="plugin author") author = Column(String, nullable=True, comment="plugin author")
email = Column(String, nullable=True, comment="plugin author email") email = Column(String, nullable=True, comment="plugin author email")
type = Column(String, comment="plugin type") type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version") version = Column(String, comment="plugin version")
storage_channel = Column(String, comment="plugin storage channel") storage_channel = Column(String, comment="plugin storage channel")
storage_url = Column(String, comment="plugin download url") storage_url = Column(String, comment="plugin download url")
download_param = Column(String, comment="plugin download param")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time") created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
installed = Column(Boolean, default=False, comment="plugin already installed") installed = Column(Boolean, default=False, comment="plugin already installed")
@@ -61,6 +63,8 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]: def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
session = self.Session() session = self.Session()
plugin_hubs = session.query(PluginHubEntity) plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
if query.id is not None: if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id) plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None: if query.name is not None:
@@ -84,6 +88,20 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size) plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all() result = plugin_hubs.all()
session.close() session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def get_by_storage_url(self, storage_url):
session = self.Session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result return result
def get_by_name(self, name: str) -> PluginHubEntity: def get_by_name(self, name: str) -> PluginHubEntity:

View File

@@ -1,39 +1,54 @@
import json
import logging import logging
import git
import os import os
import glob
import zipfile
import requests
from pathlib import Path
import datetime
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from .schema import PluginStorageType from .schema import PluginStorageType
from ..plugins_util import scan_plugins, update_from_git
logger = logging.getLogger("agent_hub") logger = logging.getLogger("agent_hub")
Default_User = "default" Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git" DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = "" TEMP_PLUGIN_PATH = ""
class AgentHub: class AgentHub:
def __init__(self, temp_hub_file_path:str = "") -> None: def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao() self.hub_dao = PluginHubDao()
self.my_lugin_dao = MyPluginDao() self.my_lugin_dao = MyPluginDao()
if temp_hub_file_path: os.makedirs(plugin_dir, exist_ok=True)
self.temp_hub_file_path = temp_hub_file_path self.plugin_dir = plugin_dir
else: self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
self.temp_hub_file_path = os.path.join(os.getcwd(), "plugins", "temp")
def install_plugin(self, plugin_name: str, user_name: str = None): def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}") logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name) plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity: if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value: if plugin_entity.storage_channel == PluginStorageType.Git.value:
try: try:
self.__download_from_git(plugin_name, plugin_entity.storage_url) branch_name = None
self.load_plugin(plugin_name) authorization = None
if plugin_entity.download_param:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
# add to my plugins and edit hub status # add to my plugins and edit hub status
plugin_entity.installed = True plugin_entity.installed = True
my_plugin_entity = self.__build_my_plugin(plugin_entity) my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name: if user_name:
# TODO use user # TODO use user
my_plugin_entity.user_code = "" my_plugin_entity.user_code = ""
@@ -55,9 +70,43 @@ class AgentHub:
else: else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!") raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
else: else:
raise ValueError(f"Can't Find Plugin {plugin_name}!") raise ValueError(f"Can't Find Plugin {plugin_name}!")
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
plugin_entity.installed = False
with self.hub_dao.Session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed:
have_installed = True
break
if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
files = glob.glob(
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
)
for file in files:
os.remove(file)
def __download_from_git(self, github_repo, branch_name, authorization):
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity: def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity() my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name my_plugin_entity.name = hub_plugin.name
@@ -65,29 +114,68 @@ class AgentHub:
my_plugin_entity.version = hub_plugin.version my_plugin_entity.version = hub_plugin.version
return my_plugin_entity return my_plugin_entity
def __fetch_from_git(self): def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
logger.info("fetch plugins from git to local path:{}", self.temp_hub_file_path) logger.info("refresh_hub_by_git start!")
os.makedirs(self.temp_hub_file_path, exist_ok=True) update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
repo = git.Repo(self.temp_hub_file_path) git_plugins = scan_plugins(self.temp_hub_file_path)
if repo.is_repo(): for git_plugin in git_plugins:
repo.remotes.origin.pull() old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
else: if old_hub_info:
git.Repo.clone_from(DEFAULT_PLUGIN_REPO, self.temp_hub_file_path) plugin_hub_info = old_hub_info
else:
plugin_hub_info = PluginHubEntity()
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
plugin_hub_info.email = getattr(git_plugin, '_email', '')
plugin_hub_info.download_param = json.dumps({
branch_name: branch_name,
authorization: authorization
})
plugin_hub_info.installed = False
# if repo.head.is_valid(): plugin_hub_info.name = git_plugin._name
# clone succ fetch plugins info plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.update(plugin_hub_info)
def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
# We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(
dir=os.path.join(self.plugin_dir)
)
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(self.plugin_dir, doc_file.filename),
)
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
for my_plugin in my_plugins:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal"
my_plugin_entiy.user_code = user
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_lugin_dao.update(my_plugin_entiy)
def upload_plugin_in_hub(self, name: str, path: str):
pass
def __download_from_git(self, plugin_name, url): def reload_my_plugins(self):
pass logger.info(f"load_plugins start!")
return scan_plugins(self.plugin_dir)
def load_plugin(self, plugin_name):
logger.info(f"load_plugin:{plugin_name}")
pass
def get_my_plugin(self, user: str): def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}") logger.info(f"get_my_plugin:{user}")
@@ -95,7 +183,3 @@ class AgentHub:
user = Default_User user = Default_User
return self.my_lugin_dao.get_by_user(user) return self.my_lugin_dao.get_by_user(user)
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
pass

View File

@@ -0,0 +1,30 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
T = TypeVar('T')
class PagenationFilter(Generic[T]):
page_index: int = 1
page_size: int = 20
filter: T = None
class PagenationResult(Generic[T]):
page_index: int = 1
page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
@dataclass
class PluginHubParam:
channel: str = Field(..., description="Plugin storage channel")
url: str = Field(..., description="Plugin storage url")
branch: str = Field(..., description="When the storage channel is github, use to specify the branch", nullable=True)
authorization: str = Field(..., description="github download authorization", nullable=True)

View File

View File

@@ -0,0 +1,2 @@
class PluginLoader():

View File

@@ -5,6 +5,7 @@ import os
import glob import glob
import zipfile import zipfile
import requests import requests
import git
import threading import threading
import datetime import datetime
from pathlib import Path from pathlib import Path
@@ -20,7 +21,6 @@ from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger from pilot.logs import logger
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]: def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
""" """
Loader zip plugin file. Native support Auto_gpt_plugin Loader zip plugin file. Native support Auto_gpt_plugin
@@ -106,7 +106,7 @@ def load_native_plugins(cfg: Config):
with open(file_name, "wb") as f: with open(file_name, "wb") as f:
f.write(response.content) f.write(response.content)
print("save file") print("save file")
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) cfg.set_plugins(scan_plugins(cfg.debug_mode))
else: else:
print("get file faildresponse code", response.status_code) print("get file faildresponse code", response.status_code)
except Exception as e: except Exception as e:
@@ -116,7 +116,30 @@ def load_native_plugins(cfg: Config):
t.start() t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
logger.info(f"__scan_plugin_file:{file_path},{debug}")
loaded_plugins = []
if moduleList := inspect_zip_for_modules(str(file_path), debug):
for module in moduleList:
plugin = Path(plugin)
module = Path(module)
logger.debug(f"Plugin: {plugin} Module: {module}")
zipped_package = zipimporter(str(plugin))
zipped_module = zipped_package.load_module(str(module.parent))
for key in dir(zipped_module):
if key.startswith("__"):
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
# and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
return loaded_plugins
def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them. """Scan the plugins directory for plugins and loads them.
Args: Args:
@@ -126,31 +149,16 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
Returns: Returns:
List[Tuple[str, Path]]: List of plugins. List[Tuple[str, Path]]: List of plugins.
""" """
loaded_plugins = []
current_dir = os.getcwd()
print(current_dir)
# Generic plugins
plugins_path_path = Path(PLUGINS_DIR)
for plugin in plugins_path_path.glob("*.zip"): loaded_plugins = []
if moduleList := inspect_zip_for_modules(str(plugin), debug): # Generic plugins
for module in moduleList: plugins_path = Path(plugins_file_path)
plugin = Path(plugin) if file_name:
module = Path(module) plugin_path = Path(plugins_path, file_name)
logger.debug(f"Plugin: {plugin} Module: {module}") loaded_plugins = __scan_plugin_file(plugin_path)
zipped_package = zipimporter(str(plugin)) else:
zipped_module = zipped_package.load_module(str(module.parent)) for plugin_path in plugins_path.glob("*.zip"):
for key in dir(zipped_module): loaded_plugins.extend(__scan_plugin_file(plugin_path))
if key.startswith("__"):
continue
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
# and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
if loaded_plugins: if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------") logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
@@ -183,5 +191,55 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
return ack.lower() == cfg.authorise_key return ack.lower() == cfg.authorise_key
if __name__ == '__main__': def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
print(inspect_zip_for_modules("/Users/tuyang.yhj/Downloads/DB-GPT-Plugins-main (1).zip")) authorization: str = "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"):
os.makedirs(download_path, exist_ok=True)
if github_repo:
if github_repo.index("github.com") <= 0:
raise ValueError("Not a correct Github repository address" + github_repo)
github_repo = github_repo.replace(".git", "")
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
plugin_repo_name = github_repo.strip('/').split('/')[-1]
else:
url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
plugin_repo_name = "DB-GPT-Plugins"
try:
session = requests.Session()
response = session.get(
url,
headers={"Authorization": authorization},
)
if response.status_code == 200:
plugins_path_path = Path(download_path)
files = glob.glob(
os.path.join(plugins_path_path, f"{plugin_repo_name}*")
)
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
return plugin_repo_name
else:
logger.error("update plugins faildresponse code", response.status_code)
raise ValueError("download plugin faild!" + response.status_code)
except Exception as e:
logger.error("update plugins from git exception!" + str(e))
raise ValueError("download plugin exception!", e)
def __fetch_from_git(local_path, git_url):
logger.info("fetch plugins from git to local path:{}", local_path)
os.makedirs(local_path, exist_ok=True)
repo = git.Repo(local_path)
if repo.is_repo():
repo.remotes.origin.pull()
else:
git.Repo.clone_from(git_url, local_path)
# if repo.head.is_valid():
# clone succ fetch plugins info

View File

@@ -56,6 +56,13 @@ class ChatScene(Enum):
param_types=["Plugin Select"], param_types=["Plugin Select"],
) )
ChatAgent = Scene(
code="chat_agent",
name="Agent Chat",
describe="Use tools through dialogue to accomplish your goals.",
param_types=["Plugin Select"],
)
InnerChatDBSummary = Scene( InnerChatDBSummary = Scene(
"inner_chat_db_summary", "DB Summary", "Db Summary.", True "inner_chat_db_summary", "DB Summary", "Db Summary.", True
) )

View File

View File

@@ -0,0 +1,72 @@
from typing import List, Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent import PluginPromptGenerator
CFG = Config()
class ChatWithPlugin(BaseChat):
chat_scene: str = ChatScene.ChatAgent.value()
plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None
def __init__(self, chat_param: Dict):
self.plugin_selector = chat_param.select_param
chat_param["chat_mode"] = ChatScene.ChatAgent
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# 加载插件中可用命令
self.select_plugin = self.plugin_selector
if self.select_plugin:
for plugin in CFG.plugins:
if plugin._name == self.plugin_selector:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
else:
for plugin in CFG.plugins:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"constraints": self.__list_to_prompt_str(
list(self.plugins_prompt_generator.constraints)
),
"commands_infos": self.plugins_prompt_generator.generate_commands_string(),
}
return input_values
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
## plugin command run
return execute_command(
str(prompt_response.command.get("name")),
prompt_response.command.get("args", {}),
self.plugins_prompt_generator,
)
def chat_show(self):
super().chat_show()
def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
def generate(self, p) -> str:
return super().generate(p)
@property
def chat_type(self) -> str:
return ChatScene.ChatAgent.value

View File

@@ -0,0 +1,23 @@
from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
{
"messages": [
{"type": "human", "data": {"content": "查询xxx", "example": True}},
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"speak\": \"thoughts summary to say to user\",
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
},
},
]
},
]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -0,0 +1,46 @@
import json
from typing import Dict, NamedTuple
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple):
command: Dict
speak: str = ""
thoughts: str = ""
class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
clean_json_str = super().parse_prompt_response(model_out_text)
print(clean_json_str)
if not clean_json_str:
raise ValueError("model server response not have json!")
try:
response = json.loads(clean_json_str)
except Exception as e:
raise ValueError("model server out not fllow the prompt!")
speak = ""
thoughts = ""
for key in sorted(response):
if key.strip() == "command":
command = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
if key.strip() == "speak":
speak = response[key]
return PluginAction(command, speak, thoughts)
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}")
view_text = f"##### {speak}" + "\n" + str(data)
return view_text
def get_format_instructions(self) -> str:
pass

View File

@@ -0,0 +1,78 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
from pilot.scene.chat_execution.example import plugin_example
CFG = Config()
_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant."
_DEFAULT_TEMPLATE_EN = """
You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals.
Tool list:
{tool_list}
Constraint:
1. After selecting an available tool, please ensure that the output results include the following parts to use the tool:
<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
2. If you cannot analyze the exact tool for the problem, you can consider using the search engine tool among the tools first.
3. Parameter content may need to be inferred based on the user's goals, not just extracted from text
4. If you cannot find a suitable tool, please answer Unable to complete the goal.
{expand_constraints}
User goals:
{user_goal}
"""
_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手"
_DEFAULT_TEMPLATE_ZH = """
请一步步思考,如何在满足下面约束条件的前提下,回答或解决用户问题或目标。
工具列表:
{tool_list}
约束条件:
1. 找到可用的工具后,请确保输出结果包含以下内容用来使用工具:<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
2.任务重可以使用多个工具,上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
3.如果有多个工具被使用,后续工具需要第一个工具的结果作为参数的, 使用如下文本来替代参数值:<api1-result>
4.如果对于问题无法理解和解决,可以考虑优先使用工具中的搜索引擎工具
5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
6.如果中无法找到合适的工具,请回答无法完成目标。
7.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中
{expand_constraints}
用户目标:
{user_goal}
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
_PROMPT_SCENE_DEFINE=(
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
RESPONSE_FORMAT = None
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatAgent.value(),
input_variables=["tool_list", "expand_constraints", "user_goal"],
response_format=None,
template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
# example_selector=plugin_example,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -1 +0,0 @@

View File

@@ -6,6 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.componet import SystemApp from pilot.componet import SystemApp
from pilot.utils.parameter_utils import BaseParameters from pilot.utils.parameter_utils import BaseParameters
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
@@ -30,7 +31,7 @@ def async_db_summery(system_app: SystemApp):
def server_init(args, system_app: SystemApp): def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry from pilot.base_modules.agent.commands.command_mange import CommandRegistry
from pilot.base_modules.agent.plugins import scan_plugins from pilot.base_modules.agent.plugins_util import scan_plugins
# logger.info(f"args: {args}") # logger.info(f"args: {args}")
@@ -43,7 +44,7 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg) # load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
# Loader plugins and commands # Loader plugins and commands
command_categories = [ command_categories = [
@@ -126,3 +127,4 @@ class WebWerverParameters(BaseParameters):
}, },
) )
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"}) light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})