fix(ChatExcel): ChatExcel OutParse Bug Fix

1.ChatDashboard Display optimization
This commit is contained in:
yhjun1026
2023-09-19 14:56:34 +08:00
parent b3e31a84b1
commit 3cafee451e
18 changed files with 563 additions and 83 deletions

View File

@@ -1,13 +0,0 @@
from abc import ABC, abstractmethod
class AgentFacade(ABC):
def __init__(self) -> None:
self.model = None

View File

@@ -27,7 +27,6 @@ def execute_ai_response_json(
user_input: str = None,
) -> str:
"""
Args:
command_registry:
ai_response:
@@ -65,6 +64,8 @@ def execute_ai_response_json(
return result
def execute_command(
command_name: str,
arguments,

View File

@@ -3,6 +3,8 @@ import time
from fastapi import (
APIRouter,
Body,
UploadFile,
File,
)
from typing import List
@@ -13,14 +15,77 @@ from pilot.openapi.api_view_model import (
Result,
)
from .model import PluginHubParam, PagenationFilter, PagenationResult
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .db.my_plugin_db import MyPluginEntity
from pilot.configs.model_config import PLUGINS_DIR
router = APIRouter()
logger = build_logger("agent_mange", LOGDIR + "agent_mange.log")
@router.get("/v1/mange/agent/list", response_model=Result[str])
async def get_agent_list():
logger.info(f"get_agent_list!")
@router.post("/v1/agent/hub/update", response_model=Result[str])
async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
return Result.succ(None)
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubEntity] = Body()):
logger.info(f"get_agent_list:{json.dumps(filter)}")
agent_hub = AgentHub(PLUGINS_DIR)
datas, total_pages, total_count = agent_hub.hub_dao.list(filter.filter, filter.page_index, filter.page_size)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
return Result.succ(result)
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user:str= None):
logger.info(f"my_agents:{json.dumps(my_agents)}")
agent_hub = AgentHub(PLUGINS_DIR)
return Result.succ(agent_hub.get_my_plugin(user))
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name:str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.install_plugin(plugin_name, user)
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name:str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.uninstall_plugin(plugin_name, user)
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.upload_my_plugin(doc_file, user)
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")

View File

@@ -18,6 +18,7 @@ class MyPluginEntity(Base):
user_code = Column(String, nullable=True, comment="user code")
user_name = Column(String, nullable=True, comment="user name")
name = Column(String, unique=True, nullable=False, comment="plugin name")
file_name = Column(String, nullable=False, comment="plugin package file name")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
@@ -74,6 +75,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
session = self.Session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
@@ -101,7 +103,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
result = my_plugins.all()
session.close()
return result
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def count(self, query: MyPluginEntity):
session = self.Session()

View File

@@ -13,12 +13,14 @@ class PluginHubEntity(Base):
__tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
name = Column(String, unique=True, nullable=False, comment="plugin name")
description = Column(String, nullable=False, comment="plugin description")
author = Column(String, nullable=True, comment="plugin author")
email = Column(String, nullable=True, comment="plugin author email")
type = Column(String, comment="plugin type")
version = Column(String, comment="plugin version")
storage_channel = Column(String, comment="plugin storage channel")
storage_url = Column(String, comment="plugin download url")
download_param = Column(String, comment="plugin download param")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin upload time")
installed = Column(Boolean, default=False, comment="plugin already installed")
@@ -61,6 +63,8 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
session = self.Session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
@@ -84,6 +88,20 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
plugin_hubs = plugin_hubs.offset((page - 1) * page_size).limit(page_size)
result = plugin_hubs.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def get_by_storage_url(self, storage_url):
session = self.Session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
session.close()
return result
def get_by_name(self, name: str) -> PluginHubEntity:

View File

@@ -1,39 +1,54 @@
import json
import logging
import git
import os
import glob
import zipfile
import requests
from pathlib import Path
import datetime
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
from ..db.my_plugin_db import MyPluginDao, MyPluginEntity
from .schema import PluginStorageType
from ..plugins_util import scan_plugins, update_from_git
logger = logging.getLogger("agent_hub")
Default_User = "default"
DEFAULT_PLUGIN_REPO = "https://github.com/eosphoros-ai/DB-GPT-Plugins.git"
TEMP_PLUGIN_PATH = ""
class AgentHub:
def __init__(self, temp_hub_file_path:str = "") -> None:
def __init__(self, plugin_dir) -> None:
self.hub_dao = PluginHubDao()
self.my_lugin_dao = MyPluginDao()
if temp_hub_file_path:
self.temp_hub_file_path = temp_hub_file_path
else:
self.temp_hub_file_path = os.path.join(os.getcwd(), "plugins", "temp")
os.makedirs(plugin_dir, exist_ok=True)
self.plugin_dir = plugin_dir
self.temp_hub_file_path = os.path.join(plugin_dir, "temp")
def install_plugin(self, plugin_name: str, user_name: str = None):
logger.info(f"install_plugin {plugin_name}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
if plugin_entity:
if plugin_entity.storage_channel == PluginStorageType.Git.value:
try:
self.__download_from_git(plugin_name, plugin_entity.storage_url)
self.load_plugin(plugin_name)
branch_name = None
authorization = None
if plugin_entity.download_param:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
# add to my plugins and edit hub status
plugin_entity.installed = True
my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name
if user_name:
# TODO use user
my_plugin_entity.user_code = ""
@@ -55,9 +70,43 @@ class AgentHub:
else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
plugin_entity = self.hub_dao.get_by_name(plugin_name)
plugin_entity.installed = False
with self.hub_dao.Session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
# delete package file if not use
plugin_infos = self.hub_dao.get_by_storage_url(plugin_entity.storage_url)
have_installed = False
for plugin_info in plugin_infos:
if plugin_info.installed:
have_installed = True
break
if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
files = glob.glob(
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
)
for file in files:
os.remove(file)
def __download_from_git(self, github_repo, branch_name, authorization):
return update_from_git(self.plugin_dir, github_repo, branch_name, authorization)
def __build_my_plugin(self, hub_plugin: PluginHubEntity) -> MyPluginEntity:
my_plugin_entity = MyPluginEntity()
my_plugin_entity.name = hub_plugin.name
@@ -65,29 +114,68 @@ class AgentHub:
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def __fetch_from_git(self):
logger.info("fetch plugins from git to local path:{}", self.temp_hub_file_path)
os.makedirs(self.temp_hub_file_path, exist_ok=True)
repo = git.Repo(self.temp_hub_file_path)
if repo.is_repo():
repo.remotes.origin.pull()
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
logger.info("refresh_hub_by_git start!")
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
git_plugins = scan_plugins(self.temp_hub_file_path)
for git_plugin in git_plugins:
old_hub_info = self.hub_dao.get_by_name(git_plugin._name)
if old_hub_info:
plugin_hub_info = old_hub_info
else:
git.Repo.clone_from(DEFAULT_PLUGIN_REPO, self.temp_hub_file_path)
plugin_hub_info = PluginHubEntity()
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
plugin_hub_info.email = getattr(git_plugin, '_email', '')
plugin_hub_info.download_param = json.dumps({
branch_name: branch_name,
authorization: authorization
})
plugin_hub_info.installed = False
# if repo.head.is_valid():
# clone succ fetch plugins info
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.update(plugin_hub_info)
def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
# We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(
dir=os.path.join(self.plugin_dir)
)
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
tmp_path,
os.path.join(self.plugin_dir, doc_file.filename),
)
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
for my_plugin in my_plugins:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal"
my_plugin_entiy.user_code = user
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_lugin_dao.update(my_plugin_entiy)
def upload_plugin_in_hub(self, name: str, path: str):
pass
def __download_from_git(self, plugin_name, url):
pass
def load_plugin(self, plugin_name):
logger.info(f"load_plugin:{plugin_name}")
pass
def reload_my_plugins(self):
logger.info(f"load_plugins start!")
return scan_plugins(self.plugin_dir)
def get_my_plugin(self, user: str):
logger.info(f"get_my_plugin:{user}")
@@ -95,7 +183,3 @@ class AgentHub:
user = Default_User
return self.my_lugin_dao.get_by_user(user)
def uninstall_plugin(self, plugin_name, user):
logger.info(f"uninstall_plugin:{plugin_name},{user}")
pass

View File

@@ -0,0 +1,30 @@
from typing import TypedDict, Optional, Dict, List
from dataclasses import dataclass
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
T = TypeVar('T')
class PagenationFilter(Generic[T]):
page_index: int = 1
page_size: int = 20
filter: T = None
class PagenationResult(Generic[T]):
page_index: int = 1
page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
@dataclass
class PluginHubParam:
channel: str = Field(..., description="Plugin storage channel")
url: str = Field(..., description="Plugin storage url")
branch: str = Field(..., description="When the storage channel is github, use to specify the branch", nullable=True)
authorization: str = Field(..., description="github download authorization", nullable=True)

View File

View File

@@ -0,0 +1,2 @@
class PluginLoader():

View File

@@ -5,6 +5,7 @@ import os
import glob
import zipfile
import requests
import git
import threading
import datetime
from pathlib import Path
@@ -20,7 +21,6 @@ from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
"""
Loader zip plugin file. Native support Auto_gpt_plugin
@@ -106,7 +106,7 @@ def load_native_plugins(cfg: Config):
with open(file_name, "wb") as f:
f.write(response.content)
print("save file")
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
cfg.set_plugins(scan_plugins(cfg.debug_mode))
else:
print("get file faildresponse code", response.status_code)
except Exception as e:
@@ -116,24 +116,10 @@ def load_native_plugins(cfg: Config):
t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
cfg (Config): Config instance including plugins config
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
logger.info(f"__scan_plugin_file:{file_path},{debug}")
loaded_plugins = []
current_dir = os.getcwd()
print(current_dir)
# Generic plugins
plugins_path_path = Path(PLUGINS_DIR)
for plugin in plugins_path_path.glob("*.zip"):
if moduleList := inspect_zip_for_modules(str(plugin), debug):
if moduleList := inspect_zip_for_modules(str(file_path), debug):
for module in moduleList:
plugin = Path(plugin)
module = Path(module)
@@ -151,6 +137,28 @@ def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate
# and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
return loaded_plugins
def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
cfg (Config): Config instance including plugins config
debug (bool, optional): Enable debug logging. Defaults to False.
Returns:
List[Tuple[str, Path]]: List of plugins.
"""
loaded_plugins = []
# Generic plugins
plugins_path = Path(plugins_file_path)
if file_name:
plugin_path = Path(plugins_path, file_name)
loaded_plugins = __scan_plugin_file(plugin_path)
else:
for plugin_path in plugins_path.glob("*.zip"):
loaded_plugins.extend(__scan_plugin_file(plugin_path))
if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
@@ -183,5 +191,55 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
return ack.lower() == cfg.authorise_key
if __name__ == '__main__':
print(inspect_zip_for_modules("/Users/tuyang.yhj/Downloads/DB-GPT-Plugins-main (1).zip"))
def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
authorization: str = "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"):
os.makedirs(download_path, exist_ok=True)
if github_repo:
if github_repo.index("github.com") <= 0:
raise ValueError("Not a correct Github repository address" + github_repo)
github_repo = github_repo.replace(".git", "")
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
plugin_repo_name = github_repo.strip('/').split('/')[-1]
else:
url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
plugin_repo_name = "DB-GPT-Plugins"
try:
session = requests.Session()
response = session.get(
url,
headers={"Authorization": authorization},
)
if response.status_code == 200:
plugins_path_path = Path(download_path)
files = glob.glob(
os.path.join(plugins_path_path, f"{plugin_repo_name}*")
)
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
return plugin_repo_name
else:
logger.error("update plugins faildresponse code", response.status_code)
raise ValueError("download plugin faild!" + response.status_code)
except Exception as e:
logger.error("update plugins from git exception!" + str(e))
raise ValueError("download plugin exception!", e)
def __fetch_from_git(local_path, git_url):
logger.info("fetch plugins from git to local path:{}", local_path)
os.makedirs(local_path, exist_ok=True)
repo = git.Repo(local_path)
if repo.is_repo():
repo.remotes.origin.pull()
else:
git.Repo.clone_from(git_url, local_path)
# if repo.head.is_valid():
# clone succ fetch plugins info

View File

@@ -56,6 +56,13 @@ class ChatScene(Enum):
param_types=["Plugin Select"],
)
ChatAgent = Scene(
code="chat_agent",
name="Agent Chat",
describe="Use tools through dialogue to accomplish your goals.",
param_types=["Plugin Select"],
)
InnerChatDBSummary = Scene(
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
)

View File

View File

@@ -0,0 +1,72 @@
from typing import List, Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent import PluginPromptGenerator
CFG = Config()
class ChatWithPlugin(BaseChat):
chat_scene: str = ChatScene.ChatAgent.value()
plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None
def __init__(self, chat_param: Dict):
self.plugin_selector = chat_param.select_param
chat_param["chat_mode"] = ChatScene.ChatAgent
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry
# 加载插件中可用命令
self.select_plugin = self.plugin_selector
if self.select_plugin:
for plugin in CFG.plugins:
if plugin._name == self.plugin_selector:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
else:
for plugin in CFG.plugins:
if not plugin.can_handle_post_prompt():
continue
self.plugins_prompt_generator = plugin.post_prompt(
self.plugins_prompt_generator
)
def generate_input_values(self):
input_values = {
"input": self.current_user_input,
"constraints": self.__list_to_prompt_str(
list(self.plugins_prompt_generator.constraints)
),
"commands_infos": self.plugins_prompt_generator.generate_commands_string(),
}
return input_values
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
## plugin command run
return execute_command(
str(prompt_response.command.get("name")),
prompt_response.command.get("args", {}),
self.plugins_prompt_generator,
)
def chat_show(self):
super().chat_show()
def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
def generate(self, p) -> str:
return super().generate(p)
@property
def chat_type(self) -> str:
return ChatScene.ChatAgent.value

View File

@@ -0,0 +1,23 @@
from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
{
"messages": [
{"type": "human", "data": {"content": "查询xxx", "example": True}},
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"speak\": \"thoughts summary to say to user\",
\"command\": {\"name\": \"command name\", \"args\": {\"arg name\": \"value\"}},
}""",
"example": True,
},
},
]
},
]
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -0,0 +1,46 @@
import json
from typing import Dict, NamedTuple
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple):
command: Dict
speak: str = ""
thoughts: str = ""
class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
clean_json_str = super().parse_prompt_response(model_out_text)
print(clean_json_str)
if not clean_json_str:
raise ValueError("model server response not have json!")
try:
response = json.loads(clean_json_str)
except Exception as e:
raise ValueError("model server out not fllow the prompt!")
speak = ""
thoughts = ""
for key in sorted(response):
if key.strip() == "command":
command = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
if key.strip() == "speak":
speak = response[key]
return PluginAction(command, speak, thoughts)
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}")
view_text = f"##### {speak}" + "\n" + str(data)
return view_text
def get_format_instructions(self) -> str:
pass

View File

@@ -0,0 +1,78 @@
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
from pilot.scene.chat_execution.example import plugin_example
CFG = Config()
_PROMPT_SCENE_DEFINE_EN = "You are a universal AI assistant."
_DEFAULT_TEMPLATE_EN = """
You need to analyze the user goals and, under the given constraints, prioritize using one of the following tools to solve the user goals.
Tool list:
{tool_list}
Constraint:
1. After selecting an available tool, please ensure that the output results include the following parts to use the tool:
<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
2. If you cannot analyze the exact tool for the problem, you can consider using the search engine tool among the tools first.
3. Parameter content may need to be inferred based on the user's goals, not just extracted from text
4. If you cannot find a suitable tool, please answer Unable to complete the goal.
{expand_constraints}
User goals:
{user_goal}
"""
_PROMPT_SCENE_DEFINE_ZH = "你是一个通用AI助手"
_DEFAULT_TEMPLATE_ZH = """
请一步步思考,如何在满足下面约束条件的前提下,回答或解决用户问题或目标。
工具列表:
{tool_list}
约束条件:
1. 找到可用的工具后,请确保输出结果包含以下内容用来使用工具:<api-call><name>Selected Tool name</name> <arg1>Parameter value</arg1><arg2 >Parameter value 2</arg2></api-call>
2.任务重可以使用多个工具,上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
3.如果有多个工具被使用,后续工具需要第一个工具的结果作为参数的, 使用如下文本来替代参数值:<api1-result>
4.如果对于问题无法理解和解决,可以考虑优先使用工具中的搜索引擎工具
5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
6.如果中无法找到合适的工具,请回答无法完成目标。
7.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中
{expand_constraints}
用户目标:
{user_goal}
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
_PROMPT_SCENE_DEFINE=(
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
RESPONSE_FORMAT = None
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ChatAgent.value(),
input_variables=["tool_list", "expand_constraints", "user_goal"],
response_format=None,
template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
# example_selector=plugin_example,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@@ -1 +0,0 @@

View File

@@ -6,6 +6,7 @@ from typing import Optional, Any
from dataclasses import dataclass, field
from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.componet import SystemApp
from pilot.utils.parameter_utils import BaseParameters
from pilot.base_modules.meta_data.meta_data import ddl_init_and_upgrade
@@ -30,7 +31,7 @@ def async_db_summery(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
from pilot.base_modules.agent.plugins import scan_plugins
from pilot.base_modules.agent.plugins_util import scan_plugins
# logger.info(f"args: {args}")
@@ -43,7 +44,7 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
cfg.set_plugins(scan_plugins(PLUGINS_DIR, cfg.debug_mode))
# Loader plugins and commands
command_categories = [
@@ -126,3 +127,4 @@ class WebWerverParameters(BaseParameters):
},
)
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})