mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
fix(core): Fix not create database bug (#946)
This commit is contained in:
parent
27536f72ad
commit
bdf9442393
@ -2,6 +2,7 @@ import signal
|
||||
import os
|
||||
import threading
|
||||
import sys
|
||||
import logging
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@ -14,6 +15,8 @@ from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
print("in order to avoid chroma db atexit problem")
|
||||
@ -110,6 +113,8 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
os.makedirs(default_meta_data_path, exist_ok=True)
|
||||
if CFG.LOCAL_DB_TYPE == "mysql":
|
||||
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
|
||||
# Try to create database, if failed, will raise exception
|
||||
_create_mysql_database(db_name, db_url, try_to_create_db)
|
||||
else:
|
||||
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
|
||||
db_url = f"sqlite:///{sqlite_db_path}"
|
||||
@ -124,6 +129,48 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
||||
return default_meta_data_path
|
||||
|
||||
|
||||
def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = False):
|
||||
"""Create mysql database if not exists
|
||||
|
||||
Args:
|
||||
db_name (str): The database name
|
||||
db_url (str): The database url, include host, port, user, password and database name
|
||||
try_to_create_db (bool, optional): Whether to try to create database. Defaults to False.
|
||||
|
||||
Raises:
|
||||
Exception: Raise exception if database operation failed
|
||||
"""
|
||||
from sqlalchemy import create_engine, DDL
|
||||
from sqlalchemy.exc import SQLAlchemyError, OperationalError
|
||||
|
||||
if not try_to_create_db:
|
||||
logger.info(f"Skipping creation of database {db_name}")
|
||||
return
|
||||
engine = create_engine(db_url)
|
||||
|
||||
try:
|
||||
# Try to connect to the database
|
||||
with engine.connect() as conn:
|
||||
logger.info(f"Database {db_name} already exists")
|
||||
return
|
||||
except OperationalError as oe:
|
||||
# If the error indicates that the database does not exist, try to create it
|
||||
if "Unknown database" in str(oe):
|
||||
try:
|
||||
# Create the database
|
||||
no_db_name_url = db_url.rsplit("/", 1)[0]
|
||||
engine_no_db = create_engine(no_db_name_url)
|
||||
with engine_no_db.connect() as conn:
|
||||
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
|
||||
logger.info(f"Database {db_name} successfully created")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Failed to create database {db_name}: {e}")
|
||||
raise
|
||||
else:
|
||||
logger.error(f"Error connecting to database {db_name}: {oe}")
|
||||
raise
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebServerParameters(BaseParameters):
|
||||
host: Optional[str] = field(
|
||||
|
0
dbgpt/app/tests/__init__.py
Normal file
0
dbgpt/app/tests/__init__.py
Normal file
71
dbgpt/app/tests/test_base.py
Normal file
71
dbgpt/app/tests/test_base.py
Normal file
@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||
|
||||
from dbgpt.app.base import _create_mysql_database
|
||||
|
||||
|
||||
@patch("sqlalchemy.create_engine")
|
||||
@patch("dbgpt.app.base.logger")
|
||||
def test_database_already_exists(mock_logger, mock_create_engine):
|
||||
mock_connection = MagicMock()
|
||||
mock_create_engine.return_value.connect.return_value.__enter__.return_value = (
|
||||
mock_connection
|
||||
)
|
||||
|
||||
_create_mysql_database(
|
||||
"test_db", "mysql+pymysql://user:password@host/test_db", True
|
||||
)
|
||||
mock_logger.info.assert_called_with("Database test_db already exists")
|
||||
mock_connection.execute.assert_not_called()
|
||||
|
||||
|
||||
@patch("sqlalchemy.create_engine")
|
||||
@patch("dbgpt.app.base.logger")
|
||||
def test_database_creation_success(mock_logger, mock_create_engine):
|
||||
# Mock the first connection failure, and the second connection success
|
||||
mock_create_engine.side_effect = [
|
||||
MagicMock(
|
||||
connect=MagicMock(
|
||||
side_effect=OperationalError("Unknown database", None, None)
|
||||
)
|
||||
),
|
||||
MagicMock(),
|
||||
]
|
||||
|
||||
_create_mysql_database(
|
||||
"test_db", "mysql+pymysql://user:password@host/test_db", True
|
||||
)
|
||||
mock_logger.info.assert_called_with("Database test_db successfully created")
|
||||
|
||||
|
||||
@patch("sqlalchemy.create_engine")
|
||||
@patch("dbgpt.app.base.logger")
|
||||
def test_database_creation_failure(mock_logger, mock_create_engine):
|
||||
# Mock the first connection failure, and the second connection failure with SQLAlchemyError
|
||||
mock_create_engine.side_effect = [
|
||||
MagicMock(
|
||||
connect=MagicMock(
|
||||
side_effect=OperationalError("Unknown database", None, None)
|
||||
)
|
||||
),
|
||||
MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))),
|
||||
]
|
||||
|
||||
with pytest.raises(SQLAlchemyError):
|
||||
_create_mysql_database(
|
||||
"test_db", "mysql+pymysql://user:password@host/test_db", True
|
||||
)
|
||||
mock_logger.error.assert_called_with(
|
||||
"Failed to create database test_db: Creation failed"
|
||||
)
|
||||
|
||||
|
||||
@patch("sqlalchemy.create_engine")
|
||||
@patch("dbgpt.app.base.logger")
|
||||
def test_skip_database_creation(mock_logger, mock_create_engine):
|
||||
_create_mysql_database(
|
||||
"test_db", "mysql+pymysql://user:password@host/test_db", False
|
||||
)
|
||||
mock_logger.info.assert_called_with("Skipping creation of database test_db")
|
||||
mock_create_engine.assert_not_called()
|
Loading…
Reference in New Issue
Block a user