# DB-GPT/dbgpt/app/tests/test_base.py
# 2024-01-10 10:39:04 +08:00
#
# 73 lines
# 2.4 KiB
# Python

from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from dbgpt.app.base import _create_mysql_database
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_already_exists(mock_logger, mock_create_engine):
    """An existing database is detected and no SQL is executed against it.

    The mocked engine's ``connect()`` context manager succeeds, so the
    existence probe passes and the function should stop there.
    """
    # Stacked @patch decorators inject mocks bottom-up: the logger patch
    # (closest to the function) is the first positional argument.
    mock_connection = MagicMock()
    mock_create_engine.return_value.connect.return_value.__enter__.return_value = (
        mock_connection
    )

    _create_mysql_database(
        "test_db", "mysql+pymysql://user:password@host/test_db", True
    )

    mock_logger.info.assert_called_with("Database test_db already exists")
    mock_connection.execute.assert_not_called()
    # The creation tests below supply a second engine for the creation path.
    # On the "already exists" path only the probe engine should be built.
    mock_create_engine.assert_called_once()
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_success(mock_logger, mock_create_engine):
    """A missing database is created through a second engine, and success is logged."""
    # First engine: connecting raises "Unknown database", signalling that the
    # target database does not exist yet.
    missing_db_engine = MagicMock(
        connect=MagicMock(
            side_effect=OperationalError("Unknown database", None, None)
        )
    )
    # Second engine: connects successfully. It is presumably the one used to
    # run the database-creation statement.
    creation_engine = MagicMock()
    mock_create_engine.side_effect = [missing_db_engine, creation_engine]

    _create_mysql_database(
        "test_db", "mysql+pymysql://user:password@host/test_db", True
    )

    mock_logger.info.assert_called_with("Database test_db successfully created")
    assert mock_create_engine.call_count == 2
    # The log line alone does not prove anything was created. Check that a
    # statement was actually executed on the second engine's connection.
    creation_conn = creation_engine.connect.return_value.__enter__.return_value
    creation_conn.execute.assert_called_once()
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_database_creation_failure(mock_logger, mock_create_engine):
    """A failure while creating the database is logged and re-raised."""
    # First engine: the target database is reported as unknown.
    # Second engine: connecting for creation fails with a generic SQLAlchemyError.
    mock_create_engine.side_effect = [
        MagicMock(
            connect=MagicMock(
                side_effect=OperationalError("Unknown database", None, None)
            )
        ),
        MagicMock(connect=MagicMock(side_effect=SQLAlchemyError("Creation failed"))),
    ]

    # OperationalError is itself a SQLAlchemyError subclass. `match` makes sure
    # the propagated error is the creation failure and not the probe error.
    with pytest.raises(SQLAlchemyError, match="Creation failed"):
        _create_mysql_database(
            "test_db", "mysql+pymysql://user:password@host/test_db", True
        )

    # These checks must sit outside the `with` block: inside it they would
    # never run, because the call above raises.
    mock_logger.error.assert_called_with(
        "Failed to create database test_db: Creation failed"
    )
    assert mock_create_engine.call_count == 2
@patch("sqlalchemy.create_engine")
@patch("dbgpt.app.base.logger")
def test_skip_database_creation(mock_logger, mock_create_engine):
    """When the create flag (third argument) is False, only a skip notice is logged."""
    _create_mysql_database(
        "test_db", "mysql+pymysql://user:password@host/test_db", False
    )

    # Exactly one info log and no engine at all: the skip path must not even
    # probe for the database.
    mock_logger.info.assert_called_once_with("Skipping creation of database test_db")
    mock_create_engine.assert_not_called()