DB-GPT/dbgpt/serve/agent/hub/db/my_plugin_db.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

162 lines
6.5 KiB
Python

from datetime import datetime
from typing import List
from sqlalchemy import Column, DateTime, Integer, String, UniqueConstraint, func
from dbgpt.serve.agent.hub.model.model import MyPluginVO
from dbgpt.storage.metadata import BaseDao, Model
class MyPluginEntity(Model):
    """ORM mapping for rows of the ``my_plugin`` table.

    Each row records a plugin together with the user (``user_code``) it
    belongs to, its package file, and simple usage counters.
    """

    __tablename__ = "my_plugin"

    id = Column(Integer, primary_key=True, comment="autoincrement id")
    tenant = Column(String(255), nullable=True, comment="user's tenant")
    user_code = Column(String(255), nullable=False, comment="user code")
    user_name = Column(String(255), nullable=True, comment="user name")
    name = Column(String(255), unique=True, nullable=False, comment="plugin name")
    file_name = Column(String(255), nullable=False, comment="plugin package file name")
    type = Column(String(255), comment="plugin type")
    version = Column(String(255), comment="plugin version")
    use_count = Column(
        Integer, nullable=True, default=0, comment="plugin total use count"
    )
    succ_count = Column(
        Integer, nullable=True, default=0, comment="plugin total success count"
    )
    sys_code = Column(String(128), index=True, nullable=True, comment="System code")
    gmt_created = Column(
        DateTime, default=datetime.utcnow, comment="plugin install time"
    )

    # NOTE(review): this expression is evaluated and then discarded, so it adds
    # no constraint to the table; SQLAlchemy only picks up table-level
    # constraints declared through ``__table_args__``. It is left untouched here
    # because attaching it would change the schema, and ``name`` is already
    # unique on its own. Confirm the intent before moving it.
    UniqueConstraint("user_code", "name", name="uk_name")

    @classmethod
    def to_vo(cls, entities: List["MyPluginEntity"]) -> List[MyPluginVO]:
        """Convert ORM rows into ``MyPluginVO`` view objects.

        Args:
            entities: Rows to convert; the input order is preserved.

        Returns:
            One ``MyPluginVO`` per entity. ``gmt_created`` is rendered as
            ``"%Y-%m-%d %H:%M:%S"``, or passed as ``None`` when the row has no
            creation time (the column is nullable) instead of raising
            ``AttributeError``.
        """
        return [
            MyPluginVO(
                id=entity.id,
                tenant=entity.tenant,
                user_code=entity.user_code,
                user_name=entity.user_name,
                sys_code=entity.sys_code,
                name=entity.name,
                file_name=entity.file_name,
                type=entity.type,
                version=entity.version,
                use_count=entity.use_count,
                succ_count=entity.succ_count,
                gmt_created=(
                    entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
                    if entity.gmt_created is not None
                    else None
                ),
            )
            for entity in entities
        ]
class MyPluginDao(BaseDao):
    """Data-access helpers for ``MyPluginEntity`` rows.

    Every method obtains its own raw session and closes it in a ``finally``
    block, so the session is released even when a query or commit raises.
    """

    # Columns that ``list`` and ``count`` filter on when the corresponding
    # attribute of the query entity is not ``None``.
    _FILTER_COLUMNS = (
        "id",
        "name",
        "tenant",
        "type",
        "user_code",
        "user_name",
        "sys_code",
    )

    @staticmethod
    def _apply_filters(rows, query: MyPluginEntity):
        """Add an equality filter to ``rows`` for each non-``None`` field of ``query``."""
        for column_name in MyPluginDao._FILTER_COLUMNS:
            wanted = getattr(query, column_name)
            if wanted is not None:
                rows = rows.filter(getattr(MyPluginEntity, column_name) == wanted)
        return rows

    def add(self, engity: MyPluginEntity):
        """Insert a copy of ``engity`` and return the new row's id.

        The counters default to ``0`` when unset and ``gmt_created`` is set to
        the current local time. ``file_name`` is copied as well: the column is
        declared NOT NULL, so omitting it makes the insert fail.

        Args:
            engity: Entity whose field values are copied into the new row.

        Returns:
            The primary key of the inserted row.
        """
        session = self.get_raw_session()
        try:
            my_plugin = MyPluginEntity(
                tenant=engity.tenant,
                user_code=engity.user_code,
                user_name=engity.user_name,
                name=engity.name,
                file_name=engity.file_name,
                type=engity.type,
                version=engity.version,
                use_count=engity.use_count or 0,
                succ_count=engity.succ_count or 0,
                sys_code=engity.sys_code,
                gmt_created=datetime.now(),
            )
            session.add(my_plugin)
            session.commit()
            # Read the id before the session closes: with SQLAlchemy's default
            # expire-on-commit behavior the attribute is reloaded on access,
            # which needs a live session.
            return my_plugin.id
        finally:
            session.close()

    def raw_update(self, entity: MyPluginEntity):
        """Merge ``entity`` into the database and return the merged row's id.

        Args:
            entity: Entity state to merge (``session.merge`` semantics).

        Returns:
            The primary key of the merged row.
        """
        session = self.get_raw_session()
        try:
            updated = session.merge(entity)
            session.commit()
            # Same as in ``add``: fetch the id while the session is still open.
            return updated.id
        finally:
            session.close()

    def get_by_user(self, user: str) -> List[MyPluginEntity]:
        """Return the plugins whose ``user_code`` equals ``user``.

        Args:
            user: User code to filter by; a falsy value disables the filter
                and returns every row.
        """
        session = self.get_raw_session()
        try:
            my_plugins = session.query(MyPluginEntity)
            if user:
                my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
            return my_plugins.all()
        finally:
            session.close()

    def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
        """Return the first plugin named ``plugin``, optionally scoped to ``user``.

        Args:
            user: User code to filter by; a falsy value disables the user filter.
            plugin: Plugin name to match exactly.

        Returns:
            The first matching row, or ``None`` when nothing matches.
        """
        session = self.get_raw_session()
        try:
            my_plugins = session.query(MyPluginEntity)
            if user:
                my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
            my_plugins = my_plugins.filter(MyPluginEntity.name == plugin)
            return my_plugins.first()
        finally:
            session.close()

    def list(
        self, query: MyPluginEntity, page=1, page_size=20
    ) -> tuple[List[MyPluginEntity], int, int]:
        """Return one page of plugins matching ``query``, newest id first.

        Args:
            query: Entity used as a filter template; every non-``None`` field
                among id, name, tenant, type, user_code, user_name and sys_code
                becomes an equality filter.
            page: 1-based page number.
            page_size: Maximum number of rows per page.

        Returns:
            A ``(rows, total_pages, total_count)`` tuple. Both totals describe
            the filtered result set, so they agree with the rows being paged.
        """
        session = self.get_raw_session()
        try:
            my_plugins = self._apply_filters(session.query(MyPluginEntity), query)
            # Count after filtering (and before paging) so the totals reflect
            # the rows the caller can actually page through.
            all_count = my_plugins.count()
            my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
            my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
            result = my_plugins.all()
        finally:
            session.close()

        # Ceiling division: a partially filled last page still counts as a page.
        total_pages = all_count // page_size
        if all_count % page_size != 0:
            total_pages += 1
        return result, total_pages, all_count

    def count(self, query: MyPluginEntity):
        """Return how many plugins match ``query``.

        Args:
            query: Filter template, interpreted exactly as in ``list``.
        """
        session = self.get_raw_session()
        try:
            my_plugins = session.query(func.count(MyPluginEntity.id))
            return self._apply_filters(my_plugins, query).scalar()
        finally:
            session.close()

    def raw_delete(self, plugin_id: int):
        """Delete the plugin row whose id is ``plugin_id``.

        Args:
            plugin_id: Primary key of the row to delete.

        Raises:
            ValueError: If ``plugin_id`` is ``None``. ``ValueError`` subclasses
                ``Exception``, so callers catching ``Exception`` still work.
        """
        # Validate before opening a session so nothing is left open on error.
        if plugin_id is None:
            raise ValueError("plugin_id is None")
        session = self.get_raw_session()
        try:
            my_plugins = session.query(MyPluginEntity)
            my_plugins.filter(MyPluginEntity.id == plugin_id).delete()
            session.commit()
        finally:
            session.close()