DB-GPT/dbgpt/datasource/manages/connect_config_db.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

299 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""DB Model for connect_config."""
import logging
from typing import Any, Dict, Optional, Union
from sqlalchemy import Column, Index, Integer, String, Text, UniqueConstraint, text
from dbgpt.serve.datasource.api.schemas import (
DatasourceServeRequest,
DatasourceServeResponse,
)
from dbgpt.storage.metadata import BaseDao, Model
logger = logging.getLogger(__name__)
class ConnectConfigEntity(Model):
"""DB connector config entity."""
__tablename__ = "connect_config"
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
db_type = Column(String(255), nullable=False, comment="db type")
db_name = Column(String(255), nullable=False, comment="db name")
db_path = Column(String(255), nullable=True, comment="file db path")
db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
db_port = Column(String(255), nullable=True, comment="db connect port(not file db)")
db_user = Column(String(255), nullable=True, comment="db user")
db_pwd = Column(String(255), nullable=True, comment="db password")
comment = Column(Text, nullable=True, comment="db comment")
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
user_id = Column(String(128), index=True, nullable=True, comment="User id")
user_name = Column(String(128), index=True, nullable=True, comment="User name")
__table_args__ = (
UniqueConstraint("db_name", name="uk_db"),
Index("idx_q_db_type", "db_type"),
)
class ConnectConfigDao(BaseDao):
"""DB connector config dao."""
def get_by_names(self, db_name: str) -> Optional[ConnectConfigEntity]:
"""Get db connect info by name."""
session = self.get_raw_session()
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
result = db_connect.first()
session.close()
return result
def add_url_db(
self,
db_name,
db_type,
db_host: str,
db_port: int,
db_user: str,
db_pwd: str,
comment: Optional[str] = None,
user_id: Optional[str] = None,
):
"""Add db connect info.
Args:
db_name: db name
db_type: db type
db_host: db host
db_port: db port
db_user: db user
db_pwd: db password
comment: comment
"""
try:
session = self.get_raw_session()
from sqlalchemy import text
insert_statement = text(
"""
INSERT INTO connect_config (
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd,
comment, user_id) VALUES (:db_name, :db_type, :db_path, :db_host,
:db_port, :db_user, :db_pwd, :comment, :user_id
)
"""
)
params = {
"db_name": db_name,
"db_type": db_type,
"db_path": "",
"db_host": db_host,
"db_port": db_port,
"db_user": db_user,
"db_pwd": db_pwd,
"comment": comment if comment else "",
"user_id": user_id if user_id else "",
}
session.execute(insert_statement, params)
session.commit()
session.close()
except Exception as e:
logger.warning("add db connect info error" + str(e))
def add_file_db(
self,
db_name,
db_type,
db_path: str,
comment: Optional[str] = None,
user_id: Optional[str] = None,
):
"""Add file db connect info."""
try:
session = self.get_raw_session()
insert_statement = text(
"""
INSERT INTO connect_config(
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd,
comment, user_id) VALUES (
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd
, :comment, :user_id
)
"""
)
params = {
"db_name": db_name,
"db_type": db_type,
"db_path": db_path,
"db_host": "",
"db_port": 0,
"db_user": "",
"db_pwd": "",
"comment": comment if comment else "",
"user_id": user_id if user_id else "",
}
session.execute(insert_statement, params)
session.commit()
session.close()
except Exception as e:
logger.warning("add db connect info error" + str(e))
def update_db_info(
self,
db_name,
db_type,
db_path: str = "",
db_host: str = "",
db_port: int = 0,
db_user: str = "",
db_pwd: str = "",
comment: str = "",
):
"""Update db connect info."""
old_db_conf = self.get_db_config(db_name)
if old_db_conf:
try:
session = self.get_raw_session()
if not db_path:
update_statement = text(
f"UPDATE connect_config set db_type='{db_type}', "
f"db_host='{db_host}', db_port={db_port}, db_user='{db_user}', "
f"db_pwd='{db_pwd}', comment='{comment}' where "
f"db_name='{db_name}'"
)
else:
update_statement = text(
f"UPDATE connect_config set db_type='{db_type}', "
f"db_path='{db_path}', comment='{comment}' where "
f"db_name='{db_name}'"
)
session.execute(update_statement)
session.commit()
session.close()
except Exception as e:
logger.warning("edit db connect info error" + str(e))
return True
raise ValueError(f"{db_name} not have config info!")
def get_db_config(self, db_name):
"""Return db connect info by name."""
session = self.get_raw_session()
if db_name:
select_statement = text(
"""
SELECT
*
FROM
connect_config
WHERE
db_name = :db_name
"""
)
params = {"db_name": db_name}
result = session.execute(select_statement, params)
else:
raise ValueError("Cannot get database by name" + db_name)
logger.info(f"Result: {result}")
fields = [field[0] for field in result.cursor.description]
row_dict = {}
row_1 = list(result.cursor.fetchall()[0])
for i, field in enumerate(fields):
row_dict[field] = row_1[i]
return row_dict
def get_db_list(self, db_name: Optional[str] = None, user_id: Optional[str] = None):
"""Get db list."""
session = self.get_raw_session()
if db_name:
sql = f"SELECT * FROM connect_config where (user_id='{user_id}' or user_id='' or user_id IS NULL) and db_name='{db_name}'" # noqa
else:
sql = f"SELECT * FROM connect_config where user_id='{user_id}' or user_id='' or user_id IS NULL" # noqa
result = session.execute(text(sql))
fields = [field[0] for field in result.cursor.description] # type: ignore
data = []
for row in result.cursor.fetchall(): # type: ignore
row_dict = {}
for i, field in enumerate(fields):
row_dict[field] = row[i]
data.append(row_dict)
return data
def delete_db(self, db_name):
"""Delete db connect info."""
session = self.get_raw_session()
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
params = {"db_name": db_name}
session.execute(delete_statement, params)
session.commit()
session.close()
return True
def from_request(
self, request: Union[DatasourceServeRequest, Dict[str, Any]]
) -> ConnectConfigEntity:
"""Convert the request to an entity.
Args:
request (Union[ServeRequest, Dict[str, Any]]): The request
Returns:
T: The entity
"""
request_dict = (
request.dict() if isinstance(request, DatasourceServeRequest) else request
)
entity = ConnectConfigEntity(**request_dict)
return entity
def to_request(self, entity: ConnectConfigEntity) -> DatasourceServeRequest:
"""Convert the entity to a request.
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return DatasourceServeRequest(
id=entity.id,
db_type=entity.db_type,
db_name=entity.db_name,
db_path=entity.db_path,
db_host=entity.db_host,
db_port=entity.db_port,
db_user=entity.db_user,
db_pwd=entity.db_pwd,
comment=entity.comment,
)
def to_response(self, entity: ConnectConfigEntity) -> DatasourceServeResponse:
"""Convert the entity to a response.
Args:
entity (T): The entity
Returns:
REQ: The request
"""
return DatasourceServeResponse(
id=entity.id,
db_type=entity.db_type,
db_name=entity.db_name,
db_path=entity.db_path,
db_host=entity.db_host,
db_port=entity.db_port,
db_user=entity.db_user,
db_pwd=entity.db_pwd,
comment=entity.comment,
)