fix(core): Fix not create database bug (#946)

This commit is contained in:
Fangyin Cheng 2023-12-16 20:59:54 +08:00 committed by GitHub
parent 27536f72ad
commit bdf9442393
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 0 deletions

View File

@ -2,6 +2,7 @@ import signal
import os
import threading
import sys
import logging
from typing import Optional
from dataclasses import dataclass, field
@ -14,6 +15,8 @@ from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
logger = logging.getLogger(__name__)
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
@ -110,6 +113,8 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
os.makedirs(default_meta_data_path, exist_ok=True)
if CFG.LOCAL_DB_TYPE == "mysql":
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
# Try to create database, if failed, will raise exception
_create_mysql_database(db_name, db_url, try_to_create_db)
else:
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
db_url = f"sqlite:///{sqlite_db_path}"
@ -124,6 +129,48 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
return default_meta_data_path
def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = False):
"""Create mysql database if not exists
Args:
db_name (str): The database name
db_url (str): The database url, include host, port, user, password and database name
try_to_create_db (bool, optional): Whether to try to create database. Defaults to False.
Raises:
Exception: Raise exception if database operation failed
"""
from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import SQLAlchemyError, OperationalError
if not try_to_create_db:
logger.info(f"Skipping creation of database {db_name}")
return
engine = create_engine(db_url)
try:
# Try to connect to the database
with engine.connect() as conn:
logger.info(f"Database {db_name} already exists")
return
except OperationalError as oe:
# If the error indicates that the database does not exist, try to create it
if "Unknown database" in str(oe):
try:
# Create the database
no_db_name_url = db_url.rsplit("/", 1)[0]
engine_no_db = create_engine(no_db_name_url)
with engine_no_db.connect() as conn:
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
logger.info(f"Database {db_name} successfully created")
except SQLAlchemyError as e:
logger.error(f"Failed to create database {db_name}: {e}")
raise
else:
logger.error(f"Error connecting to database {db_name}: {oe}")
raise
@dataclass
class WebServerParameters(BaseParameters):
host: Optional[str] = field(

View File

View File

@ -0,0 +1,71 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from dbgpt.app.base import _create_mysql_database
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_already_exists(mock_logger, mock_create_engine):
mock_connection = MagicMock()
mock_create_engine.return_value.connect.return_value.__enter__.return_value = (
mock_connection
)
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.info.assert_called_with("Database test_db already exists")
mock_connection.execute.assert_not_called()
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_success(mock_logger, mock_create_engine):
# Mock the first connection failure, and the second connection success
mock_create_engine.side_effect = [
MagicMock(
connect=MagicMock(
side_effect=OperationalError("Unknown database", None, None)
)
),
MagicMock(),
]
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.info.assert_called_with("Database test_db successfully created")
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_failure(mock_logger, mock_create_engine):
# Mock the first connection failure, and the second connection failure with SQLAlchemyError
mock_create_engine.side_effect = [
MagicMock(
connect=MagicMock(
side_effect=OperationalError("Unknown database", None, None)
)
),
MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))),
]
with pytest.raises(SQLAlchemyError):
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.error.assert_called_with(
"Failed to create database test_db: Creation failed"
)
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_skip_database_creation(mock_logger, mock_create_engine):
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", False
)
mock_logger.info.assert_called_with("Skipping creation of database test_db")
mock_create_engine.assert_not_called()