chore:doc and fmt

This commit is contained in:
aries_ckt
2023-10-17 22:00:01 +08:00
parent b4ee95c0d1
commit ef8cd442a5
7 changed files with 93 additions and 96 deletions

View File

@@ -14,7 +14,7 @@ project = "DB-GPT"
copyright = "2023, csunny" copyright = "2023, csunny"
author = "csunny" author = "csunny"
version = "👏👏 0.3.9" version = "👏👏 0.4.0"
html_title = project + " " + version html_title = project + " " + version
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -92,4 +92,9 @@ pip install chromadb==0.4.10
```commandline ```commandline
pip install langchain>=0.0.286 pip install langchain>=0.0.286
##### Q9: In Centos OS, No matching distribution found for setuptools_scm
```commandline
pip install --use-pep517 fschat
``` ```

View File

@@ -69,11 +69,17 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}") logger.info(f"agent_hub_update:{update_param.__dict__}")
try: try:
agent_hub = AgentHub(PLUGINS_DIR) agent_hub = AgentHub(PLUGINS_DIR)
branch = update_param.branch if update_param.branch is not None and len(update_param.branch) > 0 else "main" branch = (
authorization = update_param.authorization if update_param.branch is not None and len(update_param.branch) > 0 else None update_param.branch
agent_hub.refresh_hub_from_git( if update_param.branch is not None and len(update_param.branch) > 0
update_param.url, branch, authorization else "main"
) )
authorization = (
update_param.authorization
if update_param.branch is not None and len(update_param.branch) > 0
else None
)
agent_hub.refresh_hub_from_git(update_param.url, branch, authorization)
return Result.succ(None) return Result.succ(None)
except Exception as e: except Exception as e:
logger.error("Agent Hub Update Error!", e) logger.error("Agent Hub Update Error!", e)

View File

@@ -9,30 +9,33 @@ from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session from pilot.base_modules.meta_data.meta_data import Base, engine, session
class MyPluginEntity(Base): class MyPluginEntity(Base):
__tablename__ = 'my_plugin' __tablename__ = "my_plugin"
id = Column(Integer, primary_key=True, comment="autoincrement id") id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant") tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=False, comment="user code") user_code = Column(String(255), nullable=False, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name") user_name = Column(String(255), nullable=True, comment="user name")
name = Column(String(255), unique=True, nullable=False, comment="plugin name") name = Column(String(255), unique=True, nullable=False, comment="plugin name")
file_name = Column(String(255), nullable=False, comment="plugin package file name") file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type") type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version") version = Column(String(255), comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count") use_count = Column(
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count") Integer, nullable=True, default=0, comment="plugin total use count"
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
__table_args__ = (
UniqueConstraint('user_code','name', name="uk_name"),
) )
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
created_at = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
__table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)
class MyPluginDao(BaseDao[MyPluginEntity]): class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
database="dbgpt", orm_base=Base, db_engine =engine , session= session database="dbgpt", orm_base=Base, db_engine=engine, session=session
) )
def add(self, engity: MyPluginEntity): def add(self, engity: MyPluginEntity):
@@ -60,13 +63,11 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
session.commit() session.commit()
return updated.id return updated.id
def get_by_user(self, user: str)->list[MyPluginEntity]: def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_session() session = self.get_session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
if user: if user:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
MyPluginEntity.user_code == user
)
result = my_plugins.all() result = my_plugins.all()
session.close() session.close()
return result return result
@@ -75,86 +76,58 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
session = self.get_session() session = self.get_session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
if user: if user:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
MyPluginEntity.user_code == user my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
)
my_plugins = my_plugins.filter(
MyPluginEntity.name == plugin
)
result = my_plugins.first() result = my_plugins.first()
session.close() session.close()
return result return result
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
session = self.get_session() session = self.get_session()
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count() all_count = my_plugins.count()
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None: if query.name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
MyPluginEntity.name == query.name
)
if query.tenant is not None: if query.tenant is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
MyPluginEntity.tenant == query.tenant
)
if query.type is not None: if query.type is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
MyPluginEntity.type == query.type
)
if query.user_code is not None: if query.user_code is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
MyPluginEntity.user_code == query.user_code
)
if query.user_name is not None: if query.user_name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
MyPluginEntity.user_name == query.user_name
)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc()) my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size) my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
result = my_plugins.all() result = my_plugins.all()
session.close() session.close()
total_pages = all_count // page_size total_pages = all_count // page_size
if all_count % page_size != 0: if all_count % page_size != 0:
total_pages += 1 total_pages += 1
return result, total_pages, all_count return result, total_pages, all_count
def count(self, query: MyPluginEntity): def count(self, query: MyPluginEntity):
session = self.get_session() session = self.get_session()
my_plugins = session.query(func.count(MyPluginEntity.id)) my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id) my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None: if query.name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
MyPluginEntity.name == query.name
)
if query.type is not None: if query.type is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
MyPluginEntity.type == query.type
)
if query.tenant is not None: if query.tenant is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
MyPluginEntity.tenant == query.tenant
)
if query.user_code is not None: if query.user_code is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
MyPluginEntity.user_code == query.user_code
)
if query.user_name is not None: if query.user_name is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
MyPluginEntity.user_name == query.user_name
)
count = my_plugins.scalar() count = my_plugins.scalar()
session.close() session.close()
return count return count
def delete(self, plugin_id: int): def delete(self, plugin_id: int):
session = self.get_session() session = self.get_session()
if plugin_id is None: if plugin_id is None:
@@ -162,9 +135,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
query = MyPluginEntity(id=plugin_id) query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity) my_plugins = session.query(MyPluginEntity)
if query.id is not None: if query.id is not None:
my_plugins = my_plugins.filter( my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
MyPluginEntity.id == query.id
)
my_plugins.delete() my_plugins.delete()
session.commit() session.commit()
session.close() session.close()

View File

@@ -4,7 +4,7 @@ import os
import glob import glob
import shutil import shutil
from fastapi import UploadFile from fastapi import UploadFile
from typing import Any from typing import Any
import tempfile import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
@@ -38,12 +38,16 @@ class AgentHub:
download_param = json.loads(plugin_entity.download_param) download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name") branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization") authorization = download_param.get("authorization")
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization) file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status # add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1 plugin_entity.installed = plugin_entity.installed + 1
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user_name, plugin_name) my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(
user_name, plugin_name
)
if my_plugin_entity is None: if my_plugin_entity is None:
my_plugin_entity = self.__build_my_plugin(plugin_entity) my_plugin_entity = self.__build_my_plugin(plugin_entity)
my_plugin_entity.file_name = file_name my_plugin_entity.file_name = file_name
@@ -71,7 +75,9 @@ class AgentHub:
logger.error("install pluguin exception!", e) logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}") raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else: else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!") raise ValueError(
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
)
else: else:
raise ValueError(f"Can't Find Plugin {plugin_name}!") raise ValueError(f"Can't Find Plugin {plugin_name}!")
@@ -83,11 +89,13 @@ class AgentHub:
plugin_entity.installed = plugin_entity.installed - 1 plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session: with self.hub_dao.get_session() as session:
try: try:
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name) my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user: if user:
my_plugin_q.filter(MyPluginEntity.user_code == user) my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete() my_plugin_q.delete()
if plugin_entity is not None: if plugin_entity is not None:
session.merge(plugin_entity) session.merge(plugin_entity)
session.commit() session.commit()
except: except:
@@ -102,10 +110,12 @@ class AgentHub:
have_installed = True have_installed = True
break break
if not have_installed: if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1] plugin_repo_name = (
files = glob.glob( plugin_entity.storage_url.replace(".git", "")
os.path.join(self.plugin_dir, f"{plugin_repo_name}*") .strip("/")
.split("/")[-1]
) )
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files: for file in files:
os.remove(file) os.remove(file)
else: else:
@@ -125,9 +135,16 @@ class AgentHub:
my_plugin_entity.version = hub_plugin.version my_plugin_entity.version = hub_plugin.version
return my_plugin_entity return my_plugin_entity
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = "main", authorization: str = None): def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = "main",
authorization: str = None,
):
logger.info("refresh_hub_by_git start!") logger.info("refresh_hub_by_git start!")
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization) update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path) git_plugins = scan_plugins(self.temp_hub_file_path)
try: try:
for git_plugin in git_plugins: for git_plugin in git_plugins:
@@ -139,13 +156,13 @@ class AgentHub:
plugin_hub_info.type = "" plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT') plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, '_email', '') plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {} download_param = {}
if branch_name: if branch_name:
download_param['branch_name'] = branch_name download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0: if authorization and len(authorization) > 0:
download_param['authorization'] = authorization download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param) plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0 plugin_hub_info.installed = 0
@@ -156,15 +173,12 @@ class AgentHub:
except Exception as e: except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}") raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User): async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# We can not move temp file in windows system when we open file in context of `with` # We can not move temp file in windows system when we open file in context of `with`
file_path = os.path.join(self.plugin_dir, doc_file.filename) file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path): if os.path.exists(file_path):
os.remove(file_path) os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp( tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
dir=os.path.join(self.plugin_dir)
)
with os.fdopen(tmp_fd, "wb") as tmp: with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read()) tmp.write(await doc_file.read())
shutil.move( shutil.move(
@@ -174,15 +188,17 @@ class AgentHub:
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename) my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
if user is None or len(user) <=0: if user is None or len(user) <= 0:
user = Default_User user = Default_User
for my_plugin in my_plugins: for my_plugin in my_plugins:
my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(user, my_plugin._name) my_plugin_entiy = self.my_plugin_dao.get_by_user_and_plugin(
if my_plugin_entiy is None : user, my_plugin._name
my_plugin_entiy = MyPluginEntity() )
if my_plugin_entiy is None:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal" my_plugin_entiy.type = "Personal"
my_plugin_entiy.user_code = user my_plugin_entiy.user_code = user
my_plugin_entiy.user_name = user my_plugin_entiy.user_name = user
@@ -199,4 +215,3 @@ class AgentHub:
if not user: if not user:
user = Default_User user = Default_User
return self.my_plugin_dao.get_by_user(user) return self.my_plugin_dao.get_by_user(user)

View File

@@ -37,7 +37,7 @@ def server_init(args, system_app: SystemApp):
cfg = Config() cfg = Config()
cfg.SYSTEM_APP = system_app cfg.SYSTEM_APP = system_app
ddl_init_and_upgrade() # ddl_init_and_upgrade()
# load_native_plugins(cfg) # load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)

View File

@@ -316,8 +316,6 @@ def core_requires():
"jsonschema", "jsonschema",
# TODO move transformers to default # TODO move transformers to default
"transformers>=4.31.0", "transformers>=4.31.0",
"GitPython",
"alembic",
] ]
@@ -425,6 +423,8 @@ def default_requires():
"zhipuai", "zhipuai",
"dashscope", "dashscope",
"chardet", "chardet",
"GitPython",
"alembic",
] ]
setup_spec.extras["default"] += setup_spec.extras["framework"] setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["knowledge"] setup_spec.extras["default"] += setup_spec.extras["knowledge"]
@@ -465,7 +465,7 @@ init_install_requires()
setuptools.setup( setuptools.setup(
name="db-gpt", name="db-gpt",
packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")), packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")),
version="0.3.9", version="0.4.0",
author="csunny", author="csunny",
author_email="cfqcsunny@gmail.com", author_email="cfqcsunny@gmail.com",
description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment." description="DB-GPT is an experimental open-source project that uses localized GPT large models to interact with your data and environment."