refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -1,10 +1,17 @@
"""DB Model for connect_config."""
import logging
from typing import Optional
from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text
from dbgpt.storage.metadata import BaseDao, Model
logger = logging.getLogger(__name__)
class ConnectConfigEntity(Model):
"""db connect config entity"""
"""DB connector config entity."""
__tablename__ = "connect_config"
id = Column(
@@ -28,32 +35,10 @@ class ConnectConfigEntity(Model):
class ConnectConfigDao(BaseDao):
"""db connect config dao"""
"""DB connector config dao."""
def update(self, entity: ConnectConfigEntity):
"""update db connect info"""
session = self.get_raw_session()
try:
updated = session.merge(entity)
session.commit()
return updated.id
finally:
session.close()
def delete(self, db_name: int):
""" "delete db connect info"""
session = self.get_raw_session()
if db_name is None:
raise Exception("db_name is None")
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
db_connect.delete()
session.commit()
session.close()
def get_by_names(self, db_name: str) -> ConnectConfigEntity:
"""get db connect info by name"""
def get_by_names(self, db_name: str) -> Optional[ConnectConfigEntity]:
"""Get db connect info by name."""
session = self.get_raw_session()
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
@@ -71,8 +56,8 @@ class ConnectConfigDao(BaseDao):
db_pwd: str,
comment: str = "",
):
"""
add db connect info
"""Add db connect info.
Args:
db_name: db name
db_type: db type
@@ -90,9 +75,9 @@ class ConnectConfigDao(BaseDao):
insert_statement = text(
"""
INSERT INTO connect_config (
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
) VALUES (
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd,
comment) VALUES (:db_name, :db_type, :db_path, :db_host, :db_port
, :db_user, :db_pwd, :comment
)
"""
)
@@ -111,7 +96,7 @@ class ConnectConfigDao(BaseDao):
session.commit()
session.close()
except Exception as e:
print("add db connect info error" + str(e))
logger.warning("add db connect info error" + str(e))
def update_db_info(
self,
@@ -124,37 +109,43 @@ class ConnectConfigDao(BaseDao):
db_pwd: str = "",
comment: str = "",
):
"""update db connect info"""
"""Update db connect info."""
old_db_conf = self.get_db_config(db_name)
if old_db_conf:
try:
session = self.get_raw_session()
if not db_path:
update_statement = text(
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
f"UPDATE connect_config set db_type='{db_type}', "
f"db_host='{db_host}', db_port={db_port}, db_user='{db_user}', "
f"db_pwd='{db_pwd}', comment='{comment}' where "
f"db_name='{db_name}'"
)
else:
update_statement = text(
f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
f"UPDATE connect_config set db_type='{db_type}', "
f"db_path='{db_path}', comment='{comment}' where "
f"db_name='{db_name}'"
)
session.execute(update_statement)
session.commit()
session.close()
except Exception as e:
print("edit db connect info error" + str(e))
logger.warning("edit db connect info error" + str(e))
return True
raise ValueError(f"{db_name} not have config info!")
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
"""add file db connect info"""
"""Add file db connect info."""
try:
session = self.get_raw_session()
insert_statement = text(
"""
INSERT INTO connect_config(
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
) VALUES (
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd,
comment) VALUES (
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd
, :comment
)
"""
)
@@ -174,19 +165,19 @@ class ConnectConfigDao(BaseDao):
session.commit()
session.close()
except Exception as e:
print("add db connect info error" + str(e))
logger.warning("add db connect info error" + str(e))
def get_db_config(self, db_name):
"""get db config by name"""
"""Return db connect info by name."""
session = self.get_raw_session()
if db_name:
select_statement = text(
"""
SELECT
SELECT
*
FROM
connect_config
WHERE
FROM
connect_config
WHERE
db_name = :db_name
"""
)
@@ -196,7 +187,7 @@ class ConnectConfigDao(BaseDao):
else:
raise ValueError("Cannot get database by name" + db_name)
print(result)
logger.info(f"Result: {result}")
fields = [field[0] for field in result.cursor.description]
row_dict = {}
row_1 = list(result.cursor.fetchall()[0])
@@ -205,7 +196,7 @@ class ConnectConfigDao(BaseDao):
return row_dict
def get_db_list(self):
"""get db list"""
"""Get db list."""
session = self.get_raw_session()
result = session.execute(text("SELECT * FROM connect_config"))
@@ -219,7 +210,7 @@ class ConnectConfigDao(BaseDao):
return data
def delete_db(self, db_name):
"""delete db connect info"""
"""Delete db connect info."""
session = self.get_raw_session()
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
params = {"db_name": db_name}