chore:doc and fmt

This commit is contained in:
aries_ckt
2023-10-17 22:00:01 +08:00
parent b4ee95c0d1
commit ef8cd442a5
7 changed files with 93 additions and 96 deletions

View File

@@ -14,7 +14,7 @@ project = "DB-GPT"
copyright = "2023, csunny" copyright = "2023, csunny"
author = "csunny" author = "csunny"
version = "👏👏 0.3.9" version = "👏👏 0.4.0"
html_title = project + " " + version html_title = project + " " + version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -92,4 +92,9 @@ pip install chromadb==0.4.10
```commandline ```commandline
pip install langchain>=0.0.286 pip install langchain>=0.0.286
##### Q9: In Centos OS, No matching distribution found for setuptools_scm
```commandline
pip install --use-pep517 fschat
``` ```

View File

@@ -69,11 +69,17 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}") logger.info(f"agent_hub_update:{update_param.__dict__}")
try: try:
agent_hub = AgentHub(PLUGINS_DIR) agent_hub = AgentHub(PLUGINS_DIR)
branch = update_param.branch if update_param.branch is not None and len(update_param.branch) > 0 else "main" branch = (
authorization = update_param.authorization if update_param.branch is not None and len(update_param.branch) > 0 else None update_param.branch
agent_hub.refresh_hub_from_git( if update_param.branch is not None and len(update_param.branch) > 0
update_param.url, branch, authorization else "main"
) )
authorization = (
update_param.authorization
if update_param.branch is not None and len(update_param.branch) > 0
else None
)
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None) return Result.succ(None)
except Exception as e: except Exception as e:
logger.error("Agent Hub Update Error!", e) logger.error("Agent Hub Update Error!", e)

View File

@@ -9,9 +9,8 @@ from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.base_modules.meta_data.meta_data import Base, engine, session
class MyPluginEntity(Base): class MyPluginEntity(Base):
__tablename__ = 'my_plugin' __tablename__ = "my_plugin"
id = Column(Integer, primary_key=True, comment="autoincrement id") id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant") tenant = Column(String(255), nullable=True, comment="user's tenant")
@@ -21,12 +20,16 @@ class MyPluginEntity(Base):
file_name = Column(String(255), nullable=False, comment="plugin package file name") file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type") type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version") version = Column(String(255), comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") use_count = Column(
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count") Integer, nullable=True, default=0, comment="plugin total use count"
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
__table_args__ = (
UniqueConstraint('user_code','name', name="uk_name"),
) )
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
created_at = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
__table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)
class MyPluginDao(BaseDao[MyPluginEntity]): class MyPluginDao(BaseDao[MyPluginEntity]):
@@ -64,9 +67,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
session = self.get_session() session = self.get_session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
if user: if user:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
MyPluginEntity.user_code == user
)
result = my_plugins.all() result = my_plugins.all()
session.close() session.close()
return result return result
@@ -75,17 +76,12 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
session = self.get_session() session = self.get_session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
if user: if user:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
MyPluginEntity.user_code == user my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
)
my_plugins = my_plugins.filter(
MyPluginEntity.name == plugin
)
result = my_plugins.first() result = my_plugins.first()
session.close() session.close()
return result return result
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]: def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_session() session = self.get_session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
@@ -93,25 +89,15 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None: if query.name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
MyPluginEntity.name == query.name
)
if query.tenant is not None: if query.tenant is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
MyPluginEntity.tenant == query.tenant
)
if query.type is not None: if query.type is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
MyPluginEntity.type == query.type
)
if query.user_code is not None: if query.user_code is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
MyPluginEntity.user_code == query.user_code
)
if query.user_name is not None: if query.user_name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
MyPluginEntity.user_name == query.user_name
)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc()) my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size) my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
@@ -121,40 +107,27 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
if all_count % page_size != 0: if all_count % page_size != 0:
total_pages += 1 total_pages += 1
return result, total_pages, all_count return result, total_pages, all_count
def count(self, query: MyPluginEntity): def count(self, query: MyPluginEntity):
session = self.get_session() session = self.get_session()
my_plugins = session.query(func.count(MyPluginEntity.id)) my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None: if query.name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
MyPluginEntity.name == query.name
)
if query.type is not None: if query.type is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
MyPluginEntity.type == query.type
)
if query.tenant is not None: if query.tenant is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
MyPluginEntity.tenant == query.tenant
)
if query.user_code is not None: if query.user_code is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
MyPluginEntity.user_code == query.user_code
)
if query.user_name is not None: if query.user_name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
MyPluginEntity.user_name == query.user_name
)
count = my_plugins.scalar() count = my_plugins.scalar()
session.close() session.close()
return count return count
def delete(self, plugin_id: int): def delete(self, plugin_id: int):
session = self.get_session() session = self.get_session()
if plugin_id is None: if plugin_id is None:
@@ -162,9 +135,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
query = MyPluginEntity(id=plugin_id) query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
MyPluginEntity.id == query.id
)
my_plugins.delete() my_plugins.delete()
session.commit() session.commit()
session.close() session.close()

View File

@@ -38,12 +38,16 @@ class AgentHub:
download_param = json.loads(plugin_entity.download_param) download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name") branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization") authorization = download_param.get("authorization")
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization) file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status # add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1 plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user_name, plugin_name) my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user_name, plugin_name
)
if my_plugin_entity is None: if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity) my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name my_plugin_entity.file_name = file_name
@@ -71,7 +75,9 @@ class AgentHub:
logger.error("install pluguin exception!", e) logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}") raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else: else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!") raise ValueError(
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
)
else: else:
raise ValueError(f"Can't Find Plugin {plugin_name}!") raise ValueError(f"Can't Find Plugin {plugin_name}!")
@@ -83,7 +89,9 @@ class AgentHub:
plugin_entity.installed = plugin_entity.installed - 1 plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session: with self.hub_dao.get_session() as session:
try: try:
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name) my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user: if user:
my_plugin_q.filter(MyPluginEntity.user_code == user) my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete() my_plugin_q.delete()
@@ -102,10 +110,12 @@ class AgentHub:
have_installed = True have_installed = True
break break
if not have_installed: if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1] plugin_repo_name = (
files = glob.glob( plugin_entity.storage_url.replace(".git", "")
os.path.join(self.plugin_dir, f"{plugin_repo_name}*") .strip("/")
.split("/")[-1]
) )
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files: for file in files:
os.remove(file) os.remove(file)
else: else:
@@ -125,9 +135,16 @@ class AgentHub:
my_plugin_entity.version = hub_plugin.version my_plugin_entity.version = hub_plugin.version
return my_plugin_entity return my_plugin_entity
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = "main", authorization: str = None): def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!") logger.info("refresh_hub_by_git start!")
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization) update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path) git_plugins = scan_plugins(self.temp_hub_file_path)
try: try:
for git_plugin in git_plugins: for git_plugin in git_plugins:
@@ -139,13 +156,13 @@ class AgentHub:
plugin_hub_info.type = "" plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT') plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, '_email', '') plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {} download_param = {}
if branch_name: if branch_name:
download_param['branch_name'] = branch_name download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0: if authorization and len(authorization) > 0:
download_param['authorization'] = authorization download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param) plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0 plugin_hub_info.installed = 0
@@ -157,14 +174,11 @@ class AgentHub:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}") raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User): async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# We can not move temp file in windows system when we open file in context of `with` # We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename) file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path): if os.path.exists(file_path):
os.remove(file_path) os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp( tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
dir=os.path.join(self.plugin_dir)
)
with os.fdopen(tmp_fd, "wb") as tmp: with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read()) tmp.write(await doc_file.read())
shutil.move( shutil.move(
@@ -178,7 +192,9 @@ class AgentHub:
user = Default_User user = Default_User
for my_plugin in my_plugins: for my_plugin in my_plugins:
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(user, my_plugin._name) my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(
user, my_plugin._name
)
if my_plugin_entiy is None: if my_plugin_entiy is None:
my_plugin_entiy = MyPluginEntity() my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name my_plugin_entiy.name = my_plugin._name
@@ -199,4 +215,3 @@ class AgentHub:
if not user: if not user:
user = Default_User user = Default_User
return self.my_plugin_dao.get_by_user(user) return self.my_plugin_dao.get_by_user(user)

View File

@@ -37,7 +37,7 @@ def server_init(args, system_app: SystemApp):
cfg = Config() cfg = Config()
cfg.SYSTEM_APP = system_app cfg.SYSTEM_APP = system_app
ddl_init_and_upgrade() # ddl_init_and_upgrade()
# load_native_plugins(cfg) # load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)

View File

@@ -316,8 +316,6 @@ def core_requires():
"jsonschema", "jsonschema",
# TODO move transformers to default # TODO move transformers to default
"transformers>=4.31.0", "transformers>=4.31.0",
"GitPython",
"alembic",
] ]
@@ -425,6 +423,8 @@ def default_requires():
"zhipuai", "zhipuai",
"dashscope", "dashscope",
"chardet", "chardet",
"GitPython",
"alembic",
] ]
setup_spec.extras["default"] += setup_spec.extras["framework"] setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["knowledge"] setup_spec.extras["default"] += setup_spec.extras["knowledge"]
@@ -465,7 +465,7 @@ init_install_requires()
setuptools.setup( setuptools.setup(
name="db-gpt", name="db-gpt",
packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")), packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")),
version="0.3.9", version="0.4.0",
author="csunny", author="csunny",
author_email="cfqcsunny@gmail.com", author_email="cfqcsunny@gmail.com",
description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment." description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment."