refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -2,16 +2,10 @@ from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, func
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
class MyPluginEntity(Base):
class MyPluginEntity(Model):
__tablename__ = "my_plugin"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -39,16 +33,8 @@ class MyPluginEntity(Base):
class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def add(self, engity: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
my_plugin = MyPluginEntity(
tenant=engity.tenant,
user_code=engity.user_code,
@@ -68,13 +54,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
return id
def update(self, entity: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
return updated.id
def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
@@ -83,7 +69,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
return result
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
@@ -93,7 +79,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
return result
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
@@ -122,7 +108,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
return result, total_pages, all_count
def count(self, query: MyPluginEntity):
session = self.get_session()
session = self.get_raw_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
@@ -143,7 +129,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
return count
def delete(self, plugin_id: int):
session = self.get_session()
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
query = MyPluginEntity(id=plugin_id)

View File

@@ -3,19 +3,13 @@ import pytz
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
class PluginHubEntity(Base):
class PluginHubEntity(Model):
__tablename__ = "plugin_hub"
__table_args__ = {
"mysql_charset": "utf8mb4",
@@ -43,16 +37,8 @@ class PluginHubEntity(Base):
class PluginHubDao(BaseDao[PluginHubEntity]):
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def add(self, engity: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=engity.name,
@@ -71,7 +57,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
return id
def update(self, entity: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
@@ -82,7 +68,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
@@ -111,7 +97,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
return result, total_pages, all_count
def get_by_storage_url(self, storage_url):
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
result = plugin_hubs.all()
@@ -119,7 +105,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
return result
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
@@ -127,7 +113,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
return result
def count(self, query: PluginHubEntity):
session = self.get_session()
session = self.get_raw_session()
plugin_hubs = session.query(func.count(PluginHubEntity.id))
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
@@ -146,7 +132,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
return count
def delete(self, plugin_id: int):
session = self.get_session()
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)

View File

@@ -59,18 +59,12 @@ class AgentHub:
else:
my_plugin_entity.user_code = Default_User
with self.hub_dao.get_session() as session:
try:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
session.commit()
session.close()
except Exception as e:
logger.error("install merge roll back!" + str(e))
session.rollback()
with self.hub_dao.session() as session:
if my_plugin_entity.id is None:
session.add(my_plugin_entity)
else:
session.merge(my_plugin_entity)
session.merge(plugin_entity)
except Exception as e:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
@@ -87,19 +81,15 @@ class AgentHub:
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
if plugin_entity is not None:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
session.commit()
except:
session.rollback()
with self.hub_dao.session() as session:
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
if plugin_entity is not None:
session.merge(plugin_entity)
if plugin_entity is not None:
# delete package file if not use