fix(core): Fix not create database bug (#946)

This commit is contained in:
Fangyin Cheng 2023-12-16 20:59:54 +08:00 committed by GitHub
parent 27536f72ad
commit bdf9442393
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 0 deletions

View File

@ -2,6 +2,7 @@ import signal
import os import os
import threading import threading
import sys import sys
import logging
from typing import Optional from typing import Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -14,6 +15,8 @@ from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH) sys.path.append(ROOT_PATH)
logger = logging.getLogger(__name__)
def signal_handler(sig, frame): def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem") print("in order to avoid chroma db atexit problem")
@ -110,6 +113,8 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
os.makedirs(default_meta_data_path, exist_ok=True) os.makedirs(default_meta_data_path, exist_ok=True)
if CFG.LOCAL_DB_TYPE == "mysql": if CFG.LOCAL_DB_TYPE == "mysql":
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}" db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
# Try to create database, if failed, will raise exception
_create_mysql_database(db_name, db_url, try_to_create_db)
else: else:
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db") sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
db_url = f"sqlite:///{sqlite_db_path}" db_url = f"sqlite:///{sqlite_db_path}"
@ -124,6 +129,48 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
return default_meta_data_path return default_meta_data_path
def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = False):
"""Create mysql database if not exists
Args:
db_name (str): The database name
db_url (str): The database url, include host, port, user, password and database name
try_to_create_db (bool, optional): Whether to try to create database. Defaults to False.
Raises:
Exception: Raise exception if database operation failed
"""
from sqlalchemy import create_engine, DDL
from sqlalchemy.exc import SQLAlchemyError, OperationalError
if not try_to_create_db:
logger.info(f"Skipping creation of database {db_name}")
return
engine = create_engine(db_url)
try:
# Try to connect to the database
with engine.connect() as conn:
logger.info(f"Database {db_name} already exists")
return
except OperationalError as oe:
# If the error indicates that the database does not exist, try to create it
if "Unknown database" in str(oe):
try:
# Create the database
no_db_name_url = db_url.rsplit("/", 1)[0]
engine_no_db = create_engine(no_db_name_url)
with engine_no_db.connect() as conn:
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
logger.info(f"Database {db_name} successfully created")
except SQLAlchemyError as e:
logger.error(f"Failed to create database {db_name}: {e}")
raise
else:
logger.error(f"Error connecting to database {db_name}: {oe}")
raise
@dataclass @dataclass
class WebServerParameters(BaseParameters): class WebServerParameters(BaseParameters):
host: Optional[str] = field( host: Optional[str] = field(

View File

View File

@ -0,0 +1,71 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from dbgpt.app.base import _create_mysql_database
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_already_exists(mock_logger, mock_create_engine):
mock_connection = MagicMock()
mock_create_engine.return_value.connect.return_value.__enter__.return_value = (
mock_connection
)
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.info.assert_called_with("Database test_db already exists")
mock_connection.execute.assert_not_called()
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_success(mock_logger, mock_create_engine):
# Mock the first connection failure, and the second connection success
mock_create_engine.side_effect = [
MagicMock(
connect=MagicMock(
side_effect=OperationalError("Unknown database", None, None)
)
),
MagicMock(),
]
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.info.assert_called_with("Database test_db successfully created")
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_failure(mock_logger, mock_create_engine):
# Mock the first connection failure, and the second connection failure with SQLAlchemyError
mock_create_engine.side_effect = [
MagicMock(
connect=MagicMock(
side_effect=OperationalError("Unknown database", None, None)
)
),
MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))),
]
with pytest.raises(SQLAlchemyError):
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", True
)
mock_logger.error.assert_called_with(
"Failed to create database test_db: Creation failed"
)
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_skip_database_creation(mock_logger, mock_create_engine):
_create_mysql_database(
"test_db", "mysql+pymysql://user:password@host/test_db", False
)
mock_logger.info.assert_called_with("Skipping creation of database test_db")
mock_create_engine.assert_not_called()