1
0
mirror of https://github.com/haiwen/seahub.git synced 2025-04-28 03:10:45 +00:00
seahub/scripts/upgrade/fix_mysql_user.py
2021-12-04 10:45:31 +08:00

235 lines
6.6 KiB
Python

#!/usr/bin/env python
import os
import sys
import re
import configparser
import getpass
from collections import namedtuple
try:
import pymysql
HAS_PYMYSQL = True
except ImportError:
HAS_PYMYSQL = False
MySQLDBInfo = namedtuple('MySQLDBInfo', 'host port username password db')
class EnvManager(object):
def __init__(self):
self.upgrade_dir = os.path.abspath(os.path.dirname(__file__))
self.install_path = os.path.dirname(self.upgrade_dir)
self.top_dir = os.path.dirname(self.install_path)
self.ccnet_dir = os.environ['CCNET_CONF_DIR']
self.seafile_dir = os.environ['SEAFILE_CONF_DIR']
env_mgr = EnvManager()
class Utils(object):
@staticmethod
def highlight(content, is_error=False):
'''Add ANSI color to content to get it highlighted on terminal'''
if is_error:
return '\x1b[1;31m%s\x1b[m' % content
else:
return '\x1b[1;32m%s\x1b[m' % content
@staticmethod
def info(msg):
print(Utils.highlight('[INFO] ') + msg)
@staticmethod
def error(msg):
print(Utils.highlight('[ERROR] ') + msg)
sys.exit(1)
@staticmethod
def read_config(config_path, defaults):
cp = configparser.ConfigParser(defaults)
cp.read(config_path)
return cp
def get_ccnet_mysql_info():
ccnet_conf = os.path.join(env_mgr.ccnet_dir, 'ccnet.conf')
defaults = {
'HOST': '127.0.0.1',
'PORT': '3306',
}
config = Utils.read_config(ccnet_conf, defaults)
db_section = 'Database'
if not config.has_section(db_section):
return None
type = config.get(db_section, 'ENGINE')
if type != 'mysql':
return None
try:
host = config.get(db_section, 'HOST')
port = config.getint(db_section, 'PORT')
username = config.get(db_section, 'USER')
password = config.get(db_section, 'PASSWD')
db = config.get(db_section, 'DB')
except configparser.NoOptionError as e:
Utils.error('Database config in ccnet.conf is invalid: %s' % e)
info = MySQLDBInfo(host, port, username, password, db)
return info
def get_seafile_mysql_info():
seafile_conf = os.path.join(env_mgr.seafile_dir, 'seafile.conf')
defaults = {
'HOST': '127.0.0.1',
'PORT': '3306',
}
config = Utils.read_config(seafile_conf, defaults)
db_section = 'database'
if not config.has_section(db_section):
return None
type = config.get(db_section, 'type')
if type != 'mysql':
return None
try:
host = config.get(db_section, 'host')
port = config.getint(db_section, 'port')
username = config.get(db_section, 'user')
password = config.get(db_section, 'password')
db = config.get(db_section, 'db_name')
except configparser.NoOptionError as e:
Utils.error('Database config in seafile.conf is invalid: %s' % e)
info = MySQLDBInfo(host, port, username, password, db)
return info
def get_seahub_mysql_info():
sys.path.insert(0, env_mgr.top_dir)
try:
import seahub_settings# pylint: disable=F0401
except ImportError as e:
Utils.error('Failed to import seahub_settings.py: %s' % e)
if not hasattr(seahub_settings, 'DATABASES'):
return None
try:
d = seahub_settings.DATABASES['default']
if d['ENGINE'] != 'django.db.backends.mysql':
return None
host = d.get('HOST', '127.0.0.1')
port = int(d.get('PORT', 3306))
username = d['USER']
password = d['PASSWORD']
db = d['NAME']
except KeyError:
Utils.error('Database config in seahub_settings.py is invalid: %s' % e)
info = MySQLDBInfo(host, port, username, password, db)
return info
def get_seafile_db_infos():
ccnet_db_info = get_ccnet_mysql_info()
seafile_db_info = get_seafile_mysql_info()
seahub_db_info = get_seahub_mysql_info()
infos = [ccnet_db_info, seafile_db_info, seahub_db_info]
for info in infos:
if info is None:
return None
if info.host not in ('localhost', '127.0.0.1'):
return None
return infos
def ask_root_password(port):
while True:
desc = 'What is the root password for mysql? '
password = getpass.getpass(desc).strip()
if password:
try:
return check_mysql_user('root', password, port)
except InvalidAnswer as e:
print('\n%s\n' % e)
continue
class InvalidAnswer(Exception):
def __init__(self, msg):
Exception.__init__(self)
self.msg = msg
def __str__(self):
return self.msg
def check_mysql_user(user, password, port):
print('\nverifying password of root user %s ... ' % user, end=' ')
kwargs = dict(host='localhost',
port=port,
user=user,
passwd=password)
try:
conn = pymysql.connect(**kwargs)
except Exception as e:
if isinstance(e, pymysql.err.OperationalError):
raise InvalidAnswer('Failed to connect to mysql server using user "%s" and password "***": %s'
% (user, e.args[1]))
else:
raise InvalidAnswer('Failed to connect to mysql server using user "%s" and password "***": %s'
% (user, e))
print('done')
return conn
def apply_fix(root_conn, user, dbs):
for db in dbs:
grant_db_permission(root_conn, user, db)
cursor = root_conn.cursor()
sql = """
SELECT *
FROM mysql.user
WHERE Host = '%%'
AND password = ''
AND User = '%s'
""" % user
cursor.execute(sql)
if cursor.rowcount > 0:
sql = 'DROP USER `%s`@`%%`' % user
cursor.execute(sql)
def grant_db_permission(conn, user, db):
cursor = conn.cursor()
sql = '''GRANT ALL PRIVILEGES ON `%s`.* to `%s`@localhost ''' \
% (db, user)
try:
cursor.execute(sql)
except Exception as e:
if isinstance(e, pymysql.err.OperationalError):
Utils.error('Failed to grant permission of database %s: %s' % (db, e.args[1]))
else:
Utils.error('Failed to grant permission of database %s: %s' % (db, e))
finally:
cursor.close()
def main():
dbinfos = get_seafile_db_infos()
if not dbinfos:
return
if dbinfos[0].username == 'root':
return
if not HAS_PYMYSQL:
Utils.error('Python pymysql module is not found')
root_conn = ask_root_password(dbinfos[0].port)
apply_fix(root_conn, dbinfos[0].username, [info.db for info in dbinfos])
if __name__ == '__main__':
main()