mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 22:19:28 +00:00
refactor: Refactor datasource module (#1309)
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
"""DB Model for connect_config."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectConfigEntity(Model):
|
||||
"""db connect config entity"""
|
||||
"""DB connector config entity."""
|
||||
|
||||
__tablename__ = "connect_config"
|
||||
id = Column(
|
||||
@@ -28,32 +35,10 @@ class ConnectConfigEntity(Model):
|
||||
|
||||
|
||||
class ConnectConfigDao(BaseDao):
|
||||
"""db connect config dao"""
|
||||
"""DB connector config dao."""
|
||||
|
||||
def update(self, entity: ConnectConfigEntity):
|
||||
"""update db connect info"""
|
||||
session = self.get_raw_session()
|
||||
try:
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
return updated.id
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def delete(self, db_name: int):
|
||||
""" "delete db connect info"""
|
||||
session = self.get_raw_session()
|
||||
if db_name is None:
|
||||
raise Exception("db_name is None")
|
||||
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
db_connect.delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_by_names(self, db_name: str) -> ConnectConfigEntity:
|
||||
"""get db connect info by name"""
|
||||
def get_by_names(self, db_name: str) -> Optional[ConnectConfigEntity]:
|
||||
"""Get db connect info by name."""
|
||||
session = self.get_raw_session()
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
@@ -71,8 +56,8 @@ class ConnectConfigDao(BaseDao):
|
||||
db_pwd: str,
|
||||
comment: str = "",
|
||||
):
|
||||
"""
|
||||
add db connect info
|
||||
"""Add db connect info.
|
||||
|
||||
Args:
|
||||
db_name: db name
|
||||
db_type: db type
|
||||
@@ -90,9 +75,9 @@ class ConnectConfigDao(BaseDao):
|
||||
insert_statement = text(
|
||||
"""
|
||||
INSERT INTO connect_config (
|
||||
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
|
||||
) VALUES (
|
||||
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
|
||||
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd,
|
||||
comment) VALUES (:db_name, :db_type, :db_path, :db_host, :db_port
|
||||
, :db_user, :db_pwd, :comment
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -111,7 +96,7 @@ class ConnectConfigDao(BaseDao):
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print("add db connect info error!" + str(e))
|
||||
logger.warning("add db connect info error!" + str(e))
|
||||
|
||||
def update_db_info(
|
||||
self,
|
||||
@@ -124,37 +109,43 @@ class ConnectConfigDao(BaseDao):
|
||||
db_pwd: str = "",
|
||||
comment: str = "",
|
||||
):
|
||||
"""update db connect info"""
|
||||
"""Update db connect info."""
|
||||
old_db_conf = self.get_db_config(db_name)
|
||||
if old_db_conf:
|
||||
try:
|
||||
session = self.get_raw_session()
|
||||
if not db_path:
|
||||
update_statement = text(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
|
||||
f"UPDATE connect_config set db_type='{db_type}', "
|
||||
f"db_host='{db_host}', db_port={db_port}, db_user='{db_user}', "
|
||||
f"db_pwd='{db_pwd}', comment='{comment}' where "
|
||||
f"db_name='{db_name}'"
|
||||
)
|
||||
else:
|
||||
update_statement = text(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
|
||||
f"UPDATE connect_config set db_type='{db_type}', "
|
||||
f"db_path='{db_path}', comment='{comment}' where "
|
||||
f"db_name='{db_name}'"
|
||||
)
|
||||
session.execute(update_statement)
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print("edit db connect info error!" + str(e))
|
||||
logger.warning("edit db connect info error!" + str(e))
|
||||
return True
|
||||
raise ValueError(f"{db_name} not have config info!")
|
||||
|
||||
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
|
||||
"""add file db connect info"""
|
||||
"""Add file db connect info."""
|
||||
try:
|
||||
session = self.get_raw_session()
|
||||
insert_statement = text(
|
||||
"""
|
||||
INSERT INTO connect_config(
|
||||
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
|
||||
) VALUES (
|
||||
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
|
||||
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd,
|
||||
comment) VALUES (
|
||||
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd
|
||||
, :comment
|
||||
)
|
||||
"""
|
||||
)
|
||||
@@ -174,19 +165,19 @@ class ConnectConfigDao(BaseDao):
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print("add db connect info error!" + str(e))
|
||||
logger.warning("add db connect info error!" + str(e))
|
||||
|
||||
def get_db_config(self, db_name):
|
||||
"""get db config by name"""
|
||||
"""Return db connect info by name."""
|
||||
session = self.get_raw_session()
|
||||
if db_name:
|
||||
select_statement = text(
|
||||
"""
|
||||
SELECT
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
connect_config
|
||||
WHERE
|
||||
FROM
|
||||
connect_config
|
||||
WHERE
|
||||
db_name = :db_name
|
||||
"""
|
||||
)
|
||||
@@ -196,7 +187,7 @@ class ConnectConfigDao(BaseDao):
|
||||
else:
|
||||
raise ValueError("Cannot get database by name" + db_name)
|
||||
|
||||
print(result)
|
||||
logger.info(f"Result: {result}")
|
||||
fields = [field[0] for field in result.cursor.description]
|
||||
row_dict = {}
|
||||
row_1 = list(result.cursor.fetchall()[0])
|
||||
@@ -205,7 +196,7 @@ class ConnectConfigDao(BaseDao):
|
||||
return row_dict
|
||||
|
||||
def get_db_list(self):
|
||||
"""get db list"""
|
||||
"""Get db list."""
|
||||
session = self.get_raw_session()
|
||||
result = session.execute(text("SELECT * FROM connect_config"))
|
||||
|
||||
@@ -219,7 +210,7 @@ class ConnectConfigDao(BaseDao):
|
||||
return data
|
||||
|
||||
def delete_db(self, db_name):
|
||||
"""delete db connect info"""
|
||||
"""Delete db connect info."""
|
||||
session = self.get_raw_session()
|
||||
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
|
||||
params = {"db_name": db_name}
|
||||
|
Reference in New Issue
Block a user