refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -1,16 +1,10 @@
from sqlalchemy import Column, Integer, String, Index, Text, text
from sqlalchemy import UniqueConstraint
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata.meta_data import (
Base,
engine,
session,
META_DATA_DATABASE,
)
from dbgpt.storage.metadata import BaseDao, Model
class ConnectConfigEntity(Base):
class ConnectConfigEntity(Model):
"""db connect config entity"""
__tablename__ = "connect_config"
@@ -38,17 +32,9 @@ class ConnectConfigEntity(Base):
class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
"""db connect config dao"""
def __init__(self):
super().__init__(
database=META_DATA_DATABASE,
orm_base=Base,
db_engine=engine,
session=session,
)
def update(self, entity: ConnectConfigEntity):
"""update db connect info"""
session = self.get_session()
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
@@ -58,7 +44,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def delete(self, db_name: int):
""" "delete db connect info"""
session = self.get_session()
session = self.get_raw_session()
if db_name is None:
raise Exception("db_name is None")
@@ -70,7 +56,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_by_names(self, db_name: str) -> ConnectConfigEntity:
"""get db connect info by name"""
session = self.get_session()
session = self.get_raw_session()
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
result = db_connect.first()
@@ -99,7 +85,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
comment: comment
"""
try:
session = self.get_session()
session = self.get_raw_session()
from sqlalchemy import text
@@ -144,7 +130,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
old_db_conf = self.get_db_config(db_name)
if old_db_conf:
try:
session = self.get_session()
session = self.get_raw_session()
if not db_path:
update_statement = text(
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
@@ -164,7 +150,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
"""add file db connect info"""
try:
session = self.get_session()
session = self.get_raw_session()
insert_statement = text(
"""
INSERT INTO connect_config(
@@ -194,7 +180,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_db_config(self, db_name):
"""get db config by name"""
session = self.get_session()
session = self.get_raw_session()
if db_name:
select_statement = text(
"""
@@ -221,7 +207,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_db_list(self):
"""get db list"""
session = self.get_session()
session = self.get_raw_session()
result = session.execute(text("SELECT * FROM connect_config"))
fields = [field[0] for field in result.cursor.description]
@@ -235,7 +221,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def delete_db(self, db_name):
"""delete db connect info"""
session = self.get_session()
session = self.get_raw_session()
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
params = {"db_name": db_name}
session.execute(delete_statement, params)