mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 06:47:30 +00:00
fix(core): Fix not create database bug (#946)
This commit is contained in:
parent
27536f72ad
commit
bdf9442393
@ -2,6 +2,7 @@ import signal
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
@ -14,6 +15,8 @@ from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
|
|||||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(ROOT_PATH)
|
sys.path.append(ROOT_PATH)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
print("in order to avoid chroma db atexit problem")
|
print("in order to avoid chroma db atexit problem")
|
||||||
@ -110,6 +113,8 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
|||||||
os.makedirs(default_meta_data_path, exist_ok=True)
|
os.makedirs(default_meta_data_path, exist_ok=True)
|
||||||
if CFG.LOCAL_DB_TYPE == "mysql":
|
if CFG.LOCAL_DB_TYPE == "mysql":
|
||||||
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
|
db_url = f"mysql+pymysql://{quote(CFG.LOCAL_DB_USER)}:{urlquote(CFG.LOCAL_DB_PASSWORD)}@{CFG.LOCAL_DB_HOST}:{str(CFG.LOCAL_DB_PORT)}/{db_name}"
|
||||||
|
# Try to create database, if failed, will raise exception
|
||||||
|
_create_mysql_database(db_name, db_url, try_to_create_db)
|
||||||
else:
|
else:
|
||||||
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
|
sqlite_db_path = os.path.join(default_meta_data_path, f"{db_name}.db")
|
||||||
db_url = f"sqlite:///{sqlite_db_path}"
|
db_url = f"sqlite:///{sqlite_db_path}"
|
||||||
@ -124,6 +129,48 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
|
|||||||
return default_meta_data_path
|
return default_meta_data_path
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mysql_database(db_name: str, db_url: str, try_to_create_db: bool = False):
|
||||||
|
"""Create mysql database if not exists
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_name (str): The database name
|
||||||
|
db_url (str): The database url, include host, port, user, password and database name
|
||||||
|
try_to_create_db (bool, optional): Whether to try to create database. Defaults to False.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: Raise exception if database operation failed
|
||||||
|
"""
|
||||||
|
from sqlalchemy import create_engine, DDL
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError, OperationalError
|
||||||
|
|
||||||
|
if not try_to_create_db:
|
||||||
|
logger.info(f"Skipping creation of database {db_name}")
|
||||||
|
return
|
||||||
|
engine = create_engine(db_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to connect to the database
|
||||||
|
with engine.connect() as conn:
|
||||||
|
logger.info(f"Database {db_name} already exists")
|
||||||
|
return
|
||||||
|
except OperationalError as oe:
|
||||||
|
# If the error indicates that the database does not exist, try to create it
|
||||||
|
if "Unknown database" in str(oe):
|
||||||
|
try:
|
||||||
|
# Create the database
|
||||||
|
no_db_name_url = db_url.rsplit("/", 1)[0]
|
||||||
|
engine_no_db = create_engine(no_db_name_url)
|
||||||
|
with engine_no_db.connect() as conn:
|
||||||
|
conn.execute(DDL(f"CREATE DATABASE {db_name}"))
|
||||||
|
logger.info(f"Database {db_name} successfully created")
|
||||||
|
except SQLAlchemyError as e:
|
||||||
|
logger.error(f"Failed to create database {db_name}: {e}")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
logger.error(f"Error connecting to database {db_name}: {oe}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WebServerParameters(BaseParameters):
|
class WebServerParameters(BaseParameters):
|
||||||
host: Optional[str] = field(
|
host: Optional[str] = field(
|
||||||
|
0
dbgpt/app/tests/__init__.py
Normal file
0
dbgpt/app/tests/__init__.py
Normal file
71
dbgpt/app/tests/test_base.py
Normal file
71
dbgpt/app/tests/test_base.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from sqlalchemy.exc import OperationalError, SQLAlchemyError
|
||||||
|
|
||||||
|
from dbgpt.app.base import _create_mysql_database
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sqlalchemy.create_engine")
|
||||||
|
@patch("dbgpt.app.base.logger")
|
||||||
|
def test_database_already_exists(mock_logger, mock_create_engine):
|
||||||
|
mock_connection = MagicMock()
|
||||||
|
mock_create_engine.return_value.connect.return_value.__enter__.return_value = (
|
||||||
|
mock_connection
|
||||||
|
)
|
||||||
|
|
||||||
|
_create_mysql_database(
|
||||||
|
"test_db", "mysql+pymysql://user:password@host/test_db", True
|
||||||
|
)
|
||||||
|
mock_logger.info.assert_called_with("Database test_db already exists")
|
||||||
|
mock_connection.execute.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sqlalchemy.create_engine")
|
||||||
|
@patch("dbgpt.app.base.logger")
|
||||||
|
def test_database_creation_success(mock_logger, mock_create_engine):
|
||||||
|
# Mock the first connection failure, and the second connection success
|
||||||
|
mock_create_engine.side_effect = [
|
||||||
|
MagicMock(
|
||||||
|
connect=MagicMock(
|
||||||
|
side_effect=OperationalError("Unknown database", None, None)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
MagicMock(),
|
||||||
|
]
|
||||||
|
|
||||||
|
_create_mysql_database(
|
||||||
|
"test_db", "mysql+pymysql://user:password@host/test_db", True
|
||||||
|
)
|
||||||
|
mock_logger.info.assert_called_with("Database test_db successfully created")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sqlalchemy.create_engine")
|
||||||
|
@patch("dbgpt.app.base.logger")
|
||||||
|
def test_database_creation_failure(mock_logger, mock_create_engine):
|
||||||
|
# Mock the first connection failure, and the second connection failure with SQLAlchemyError
|
||||||
|
mock_create_engine.side_effect = [
|
||||||
|
MagicMock(
|
||||||
|
connect=MagicMock(
|
||||||
|
side_effect=OperationalError("Unknown database", None, None)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))),
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(SQLAlchemyError):
|
||||||
|
_create_mysql_database(
|
||||||
|
"test_db", "mysql+pymysql://user:password@host/test_db", True
|
||||||
|
)
|
||||||
|
mock_logger.error.assert_called_with(
|
||||||
|
"Failed to create database test_db: Creation failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sqlalchemy.create_engine")
|
||||||
|
@patch("dbgpt.app.base.logger")
|
||||||
|
def test_skip_database_creation(mock_logger, mock_create_engine):
|
||||||
|
_create_mysql_database(
|
||||||
|
"test_db", "mysql+pymysql://user:password@host/test_db", False
|
||||||
|
)
|
||||||
|
mock_logger.info.assert_called_with("Skipping creation of database test_db")
|
||||||
|
mock_create_engine.assert_not_called()
|
Loading…
Reference in New Issue
Block a user