mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 12:37:14 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
@@ -2,16 +2,10 @@ from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, DateTime, func
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
|
||||
class MyPluginEntity(Base):
|
||||
class MyPluginEntity(Model):
|
||||
__tablename__ = "my_plugin"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -39,16 +33,8 @@ class MyPluginEntity(Base):
|
||||
|
||||
|
||||
class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def add(self, engity: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
my_plugin = MyPluginEntity(
|
||||
tenant=engity.tenant,
|
||||
user_code=engity.user_code,
|
||||
@@ -68,13 +54,13 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
return id
|
||||
|
||||
def update(self, entity: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
return updated.id
|
||||
|
||||
def get_by_user(self, user: str) -> list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
|
||||
@@ -83,7 +69,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
return result
|
||||
|
||||
def get_by_user_and_plugin(self, user: str, plugin: str) -> MyPluginEntity:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
|
||||
@@ -93,7 +79,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
return result
|
||||
|
||||
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
all_count = my_plugins.count()
|
||||
if query.id is not None:
|
||||
@@ -122,7 +108,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
return result, total_pages, all_count
|
||||
|
||||
def count(self, query: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
my_plugins = session.query(func.count(MyPluginEntity.id))
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
@@ -143,7 +129,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
return count
|
||||
|
||||
def delete(self, plugin_id: int):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if plugin_id is None:
|
||||
raise Exception("plugin_id is None")
|
||||
query = MyPluginEntity(id=plugin_id)
|
||||
|
@@ -3,19 +3,13 @@ import pytz
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, DDL
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
# TODO We should consider that the production environment does not have permission to execute the DDL
|
||||
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
|
||||
|
||||
|
||||
class PluginHubEntity(Base):
|
||||
class PluginHubEntity(Model):
|
||||
__tablename__ = "plugin_hub"
|
||||
__table_args__ = {
|
||||
"mysql_charset": "utf8mb4",
|
||||
@@ -43,16 +37,8 @@ class PluginHubEntity(Base):
|
||||
|
||||
|
||||
class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def add(self, engity: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
timezone = pytz.timezone("Asia/Shanghai")
|
||||
plugin_hub = PluginHubEntity(
|
||||
name=engity.name,
|
||||
@@ -71,7 +57,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
return id
|
||||
|
||||
def update(self, entity: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
try:
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
@@ -82,7 +68,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
def list(
|
||||
self, query: PluginHubEntity, page=1, page_size=20
|
||||
) -> list[PluginHubEntity]:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
all_count = plugin_hubs.count()
|
||||
|
||||
@@ -111,7 +97,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
return result, total_pages, all_count
|
||||
|
||||
def get_by_storage_url(self, storage_url):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.storage_url == storage_url)
|
||||
result = plugin_hubs.all()
|
||||
@@ -119,7 +105,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
return result
|
||||
|
||||
def get_by_name(self, name: str) -> PluginHubEntity:
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
|
||||
result = plugin_hubs.first()
|
||||
@@ -127,7 +113,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
return result
|
||||
|
||||
def count(self, query: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
plugin_hubs = session.query(func.count(PluginHubEntity.id))
|
||||
if query.id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||
@@ -146,7 +132,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
return count
|
||||
|
||||
def delete(self, plugin_id: int):
|
||||
session = self.get_session()
|
||||
session = self.get_raw_session()
|
||||
if plugin_id is None:
|
||||
raise Exception("plugin_id is None")
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
|
@@ -59,18 +59,12 @@ class AgentHub:
|
||||
else:
|
||||
my_plugin_entity.user_code = Default_User
|
||||
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
if my_plugin_entity.id is None:
|
||||
session.add(my_plugin_entity)
|
||||
else:
|
||||
session.merge(my_plugin_entity)
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
logger.error("install merge roll back!" + str(e))
|
||||
session.rollback()
|
||||
with self.hub_dao.session() as session:
|
||||
if my_plugin_entity.id is None:
|
||||
session.add(my_plugin_entity)
|
||||
else:
|
||||
session.merge(my_plugin_entity)
|
||||
session.merge(plugin_entity)
|
||||
except Exception as e:
|
||||
logger.error("install pluguin exception!", e)
|
||||
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
|
||||
@@ -87,19 +81,15 @@ class AgentHub:
|
||||
my_plugin_entity = self.my_plugin_dao.get_by_user_and_plugin(user, plugin_name)
|
||||
if plugin_entity is not None:
|
||||
plugin_entity.installed = plugin_entity.installed - 1
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(
|
||||
MyPluginEntity.name == plugin_name
|
||||
)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
if plugin_entity is not None:
|
||||
session.merge(plugin_entity)
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
with self.hub_dao.session() as session:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(
|
||||
MyPluginEntity.name == plugin_name
|
||||
)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
if plugin_entity is not None:
|
||||
session.merge(plugin_entity)
|
||||
|
||||
if plugin_entity is not None:
|
||||
# delete package file if not use
|
||||
|
Reference in New Issue
Block a user