DB-GPT/tests/intetration_tests/datasource/test_conn_mysql.py
2024-03-18 18:06:40 +08:00

93 lines
2.3 KiB
Python

"""
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_mysql.py
docker run -itd --name mysql-test -p 3307:3306 -e MYSQL_ROOT_PASSWORD=12345678 mysql:5.7
mysql -h 127.0.0.1 -uroot -p -P3307
Enter password:
Welcome to the MySQL monitor. Commands end with ; or \g.
Your MySQL connection id is 2
Server version: 5.7.41 MySQL Community Server (GPL)
Copyright (c) 2000, 2023, Oracle and/or its affiliates.
Oracle is a registered trademark of Oracle Corporation and/or its
affiliates. Other names may be trademarks of their respective
owners.
Type 'help;' or '\h' for help. Type '\c' to clear the current input statement.
> create database test;
"""
import pytest
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnector
_create_table_sql = """
CREATE TABLE IF NOT EXISTS `test` (
`id` int(11) DEFAULT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
@pytest.fixture
def db():
conn = MySQLConnector.from_uri_db(
"localhost",
3307,
"root",
"********",
"test",
engine_args={"connect_args": {"charset": "utf8mb4"}},
)
yield conn
def test_get_usable_table_names(db):
db.run(_create_table_sql)
print(db._sync_tables_from_db())
assert list(db.get_usable_table_names()) == []
def test_get_table_info(db):
assert "CREATE TABLE test" in db.get_table_info()
def test_get_table_info_with_table(db):
db.run(_create_table_sql)
print(db._sync_tables_from_db())
table_info = db.get_table_info()
assert "CREATE TABLE test" in table_info
def test_run_no_throw(db):
assert db.run_no_throw("this is a error sql").startswith("Error:")
def test_get_index_empty(db):
db.run(_create_table_sql)
assert db.get_indexes("test") == []
def test_get_fields(db):
db.run(_create_table_sql)
assert list(db.get_fields("test")[0])[0] == "id"
def test_get_charset(db):
assert db.get_charset() == "utf8mb4" or db.get_charset() == "latin1"
def test_get_collation(db):
assert (
db.get_collation() == "utf8mb4_general_ci"
or db.get_collation() == "latin1_swedish_ci"
)
def test_get_users(db):
assert ("root", "%") in db.get_users()
def test_get_database_lists(db):
assert db.get_database_names() == ["test"]