mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-22 09:28:42 +00:00
85 lines
3.2 KiB
Python
85 lines
3.2 KiB
Python
import logging
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from dbgpt._private.config import Config
|
|
from dbgpt.storage.schema import DBType
|
|
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
|
|
|
# Logger for this module, keyed by the module's dotted name.
logger = logging.getLogger(name=__name__)

# Configuration object; the LOCAL_DB_* settings read by _get_db_engine come from here.
CFG = Config()
|
|
|
|
|
|
class BaseDao:
    """Base data-access object whose engine and session factory are built lazily.

    Nothing touches the database at construction time; the engine is resolved
    through ``_get_db_engine`` on the first access of ``db_engine``.
    """

    def __init__(
        self, orm_base=None, database: str = None, create_not_exist_table: bool = False
    ) -> None:
        """Store connection parameters for later, lazy engine creation.

        Args:
            orm_base: ORM base whose ``metadata`` may be used to create tables
                that do not exist yet; may be ``None``.
            database: Name of the database to connect to.
            create_not_exist_table: When ``True`` and the configured database is
                a file database, missing tables are created automatically.
        """
        # Resolved on first use; see the ``db_engine`` and ``Session`` properties.
        self._db_engine = None
        self._session = None
        self._connection = None

        self._orm_base = orm_base
        self._database = database
        self._create_not_exist_table = create_not_exist_table

    @property
    def db_engine(self):
        """Engine for this DAO, created on first access and cached afterwards.

        Creating the engine also records the accompanying connection object
        (``None`` for the default MySQL backend) in ``self._connection``.
        """
        if self._db_engine:
            return self._db_engine
        engine, conn = _get_db_engine(
            self._orm_base, self._database, self._create_not_exist_table
        )
        self._db_engine, self._connection = engine, conn
        return self._db_engine

    @property
    def Session(self):
        """Session factory bound to ``db_engine``, built once and then reused."""
        factory = self._session
        if not factory:
            factory = sessionmaker(bind=self.db_engine)
            self._session = factory
        return factory
|
|
|
|
|
|
def _quote_url_component(value) -> str:
    """Percent-encode ``value`` so it can be embedded safely in a database URL.

    Every character outside the URL "unreserved" set is escaped (including
    ``@``, ``:``, ``/``, ``#``, ``%``, ``+`` and spaces), so credentials that
    contain them cannot corrupt the URL structure. Non-string values are
    converted with ``str()`` first, matching the original f-string behavior.
    """
    from urllib.parse import quote

    return quote(str(value), safe="")


def _get_db_engine(
    orm_base=None, database: str = None, create_not_exist_table: bool = False
):
    """Create the engine (and, for managed databases, the connection) for a DAO.

    The backend is selected from ``CFG.LOCAL_DB_TYPE``. When that type is MySQL,
    or is not recognised by ``DBType.of_db_type`` (which then returns ``None``),
    a MySQL engine is built from the ``LOCAL_DB_*`` settings. Any other type is
    resolved through ``CFG.LOCAL_DB_MANAGE``. For file databases, the database
    name is taken from ``CFG.LOCAL_DB_PATH`` instead of the ``database`` argument.

    Args:
        orm_base: ORM base whose ``metadata`` is used to create missing tables.
        database: Database name. It is overwritten for file databases (see above).
        create_not_exist_table: Create missing tables from ``orm_base.metadata``.
            Only honoured for file databases.

    Returns:
        A ``(db_engine, connection)`` tuple. ``connection`` is ``None`` on the
        MySQL path. Otherwise it is the ``RDBMSDatabase`` obtained from the manager.

    Raises:
        Exception: ``LOCAL_DB_MANAGE`` is not initialized for a non-MySQL type.
        ValueError: A file database is configured without ``LOCAL_DB_PATH``, or
            the manager returned a connection that is not an ``RDBMSDatabase``.
    """
    db_engine = None
    connection: RDBMSDatabase = None

    db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
    if db_type is None or db_type == DBType.Mysql:
        # Default database. Credentials are percent-encoded so characters such
        # as "@", ":" or "/" inside them cannot break the URL.
        user = _quote_url_component(CFG.LOCAL_DB_USER)
        password = _quote_url_component(CFG.LOCAL_DB_PASSWORD)
        db_engine = create_engine(
            f"mysql+pymysql://{user}:{password}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
    else:
        db_manager = CFG.LOCAL_DB_MANAGE
        if not db_manager:
            raise Exception(
                "LOCAL_DB_MANAGE is not initialized, please check the system configuration"
            )
        if db_type.is_file_db():
            db_path = CFG.LOCAL_DB_PATH
            if db_path is None or db_path == "":
                raise ValueError(
                    "You LOCAL_DB_TYPE is file db, but LOCAL_DB_PATH is not configured, please configure LOCAL_DB_PATH in you .env file"
                )
            # For file databases the name comes from the configured path.
            _, database = db_manager._parse_file_db_info(db_type.value(), db_path)
            logger.info(
                f"Current DAO database is file database, db_type: {db_type.value()}, db_path: {db_path}, db_name: {database}"
            )
        logger.info(f"Get DAO database connection with database name {database}")
        connection = db_manager.get_connect(database)
        if not isinstance(connection, RDBMSDatabase):
            raise ValueError(
                "Currently only supports `RDBMSDatabase` database as the underlying database of BaseDao, please check your database configuration"
            )
        db_engine = connection._engine

    # db_type is None on the default-MySQL path (unrecognised LOCAL_DB_TYPE).
    # Check that first instead of calling is_file_db() on None.
    if (
        db_type is not None
        and db_type.is_file_db()
        and orm_base is not None
        and create_not_exist_table
    ):
        logger.info("Current database is file database, create not exist table")
        orm_base.metadata.create_all(db_engine)

    return db_engine, connection
|