1
0
mirror of https://github.com/haiwen/seahub.git synced 2025-10-25 14:11:07 +00:00
Files
seahub/scripts/upgrade/db_update_helper.py
2025-07-31 11:31:18 +08:00

317 lines
10 KiB
Python

# coding: UTF-8
import sys
import os
import configparser
import glob
HAS_PYMYSQL = True
try:
import pymysql
except ImportError:
HAS_PYMYSQL = False
HAS_SQLITE3 = True
try:
import sqlite3
except ImportError:
HAS_SQLITE3 = False
class EnvManager(object):
    '''Resolve the directories and environment settings used by the upgrade.

    Attributes set in ``__init__``:
        upgrade_dir: absolute directory that contains this script.
        install_path: parent directory of ``upgrade_dir``.
        top_dir: parent directory of ``install_path``.
        ccnet_dir: value of ``CCNET_CONF_DIR`` (``None`` when unset).
        seafile_dir: value of ``SEAFILE_CONF_DIR`` (``None`` when unset).
        central_config_dir: value of ``SEAFILE_CENTRAL_CONF_DIR``
            (``None`` when unset).
        seafevents_db_dir: the ``pro-data`` directory under ``top_dir``.
    '''

    def __init__(self):
        # Anchor on an absolute path: when __file__ is a bare relative name,
        # dirname() returns '' and every derived directory would collapse to
        # '' as well, making later path checks depend on the current
        # working directory.
        self.upgrade_dir = os.path.dirname(os.path.abspath(__file__))
        self.install_path = os.path.dirname(self.upgrade_dir)
        self.top_dir = os.path.dirname(self.install_path)

        self.ccnet_dir = os.environ.get('CCNET_CONF_DIR')
        self.seafile_dir = os.environ.get('SEAFILE_CONF_DIR')
        self.central_config_dir = os.environ.get('SEAFILE_CENTRAL_CONF_DIR')

        # Same location as dirname(install_path)/pro-data.
        self.seafevents_db_dir = os.path.join(self.top_dir, 'pro-data')


# Module-level instance shared by the rest of this script.
env_mgr = EnvManager()
class Utils(object):
    '''Console output and config-file helpers for the upgrade script.'''

    @staticmethod
    def highlight(content, is_error=False):
        '''Add ANSI color to content to get it highlighted on terminal.

        Bold red is used when ``is_error`` is true, bold green otherwise.
        '''
        if is_error:
            return '\x1b[1;31m%s\x1b[m' % content
        else:
            return '\x1b[1;32m%s\x1b[m' % content

    @staticmethod
    def info(msg):
        '''Print ``msg`` prefixed with a highlighted ``[INFO]`` tag.'''
        print(Utils.highlight('[INFO] ') + msg)

    @staticmethod
    def warning(msg):
        '''Print ``msg`` prefixed with a highlighted ``[WARNING]`` tag.'''
        print(Utils.highlight('[WARNING] ') + msg)

    @staticmethod
    def error(msg):
        '''Print ``msg`` with a red ``[ERROR]`` tag and exit with status 1.

        Raises:
            SystemExit: always, via ``sys.exit(1)``.
        '''
        # Pass is_error=True so the tag is actually rendered in the error
        # color; previously the flag was never used and errors looked like
        # ordinary info lines.
        print(Utils.highlight('[ERROR] ', is_error=True) + msg)
        sys.exit(1)

    @staticmethod
    def read_config(config_path, defaults):
        '''Parse the INI file at ``config_path`` and return the parser.

        ``defaults`` is handed to ``configparser.ConfigParser`` as its
        default values. If the path does not exist the script is aborted
        through ``Utils.error``.
        '''
        if not os.path.exists(config_path):
            Utils.error('Config path %s doesn\'t exist, stop db upgrade' %
                        config_path)

        cp = configparser.ConfigParser(defaults)
        cp.read(config_path)
        return cp
class MySQLDBInfo(object):
    '''Plain holder for the parameters of one MySQL database connection.

    All constructor arguments are stored unchanged as attributes of the
    same name; ``unix_socket`` is optional and defaults to ``None``.
    '''

    def __init__(self, host, port, username, password, db, unix_socket=None):
        # Where to connect.
        self.host, self.port = host, port
        self.unix_socket = unix_socket
        # Who connects.
        self.username, self.password = username, password
        # Which database to use.
        self.db = db
class DBUpdater(object):
    '''Base class that applies the SQL upgrade files of one version.

    The SQL files live in ``<upgrade_dir>/sql/<version>/<name>``. Concrete
    subclasses implement the ``update_*_sql`` hooks for a specific backend.
    '''

    # Port used when SEAFILE_MYSQL_DB_PORT is unset or blank.
    _DEFAULT_MYSQL_PORT = 3306

    def __init__(self, version, name):
        '''Remember ``version`` and locate the SQL directory for ``name``.'''
        self.sql_dir = os.path.join(env_mgr.upgrade_dir, 'sql', version, name)
        pro_path = os.path.join(env_mgr.install_path, 'pro')
        # A "pro" directory under the install path marks a pro installation.
        self.is_pro = os.path.exists(pro_path)
        self.version = version

    @staticmethod
    def get_instance(version):
        '''Detect whether we are using mysql or sqlite3.

        Returns a ``MySQLDBUpdater`` when all three database infos are
        available and a ``SQLiteDBUpdater`` when none is. A mixed setup
        aborts the script through ``Utils.error``.

        Note: the ``get_*_mysql_info`` helpers of this class always return
        an info object, so only the MySQL branch is reachable unless a
        subclass or future change makes them return ``None``.
        '''
        ccnet_db_info = DBUpdater.get_ccnet_mysql_info(version)
        seafile_db_info = DBUpdater.get_seafile_mysql_info(version)
        seahub_db_info = DBUpdater.get_seahub_mysql_info()

        if ccnet_db_info and seafile_db_info and seahub_db_info:
            Utils.info('You are using MySQL')
            if not HAS_PYMYSQL:
                Utils.error('Python pymysql module is not found')
            return MySQLDBUpdater(version, ccnet_db_info, seafile_db_info,
                                  seahub_db_info)

        if (ccnet_db_info is None) and (seafile_db_info is None) \
                and (seahub_db_info is None):
            Utils.info('You are using SQLite3')
            if not HAS_SQLITE3:
                Utils.error('Python sqlite3 module is not found')
            return SQLiteDBUpdater(version)

        def to_db_string(info):
            if info is None:
                return 'SQLite3'
            else:
                return 'MySQL'

        # Utils.error terminates the process, so nothing is returned here.
        Utils.error('Error:\n ccnet is using %s\n seafile is using %s\n seahub is using %s\n'
                    % (to_db_string(ccnet_db_info),
                       to_db_string(seafile_db_info),
                       to_db_string(seahub_db_info)))

    def update_db(self):
        '''Apply every SQL file that exists in ``self.sql_dir``.

        Files are looked up by fixed name (``ccnet.sql``, ``seafile.sql``,
        ``seahub.sql``, ``seafevents.sql``) and handed to the matching
        ``update_*_sql`` hook; missing files are skipped silently.
        '''
        ccnet_sql = os.path.join(self.sql_dir, 'ccnet.sql')
        seafile_sql = os.path.join(self.sql_dir, 'seafile.sql')
        seahub_sql = os.path.join(self.sql_dir, 'seahub.sql')
        seafevents_sql = os.path.join(self.sql_dir, 'seafevents.sql')

        if os.path.exists(ccnet_sql):
            Utils.info('updating ccnet database...')
            self.update_ccnet_sql(ccnet_sql)

        if os.path.exists(seafile_sql):
            Utils.info('updating seafile database...')
            self.update_seafile_sql(seafile_sql)

        if os.path.exists(seahub_sql):
            Utils.info('updating seahub database...')
            self.update_seahub_sql(seahub_sql)

        # No message here: the seafevents hooks print their own.
        if os.path.exists(seafevents_sql):
            self.update_seafevents_sql(seafevents_sql)

    @staticmethod
    def _parse_port(raw):
        '''Convert the raw ``SEAFILE_MYSQL_DB_PORT`` value to an int.

        ``None`` and blank strings fall back to the default port instead of
        crashing in ``int('')``.

        Raises:
            ValueError: if the value is set but not an integer.
        '''
        if raw is None:
            return DBUpdater._DEFAULT_MYSQL_PORT
        text = str(raw).strip()
        if not text:
            return DBUpdater._DEFAULT_MYSQL_PORT
        try:
            return int(text)
        except ValueError:
            raise ValueError(
                'SEAFILE_MYSQL_DB_PORT must be an integer, got %r' % (raw,))

    @staticmethod
    def _get_mysql_info(db_name_env, default_db_name):
        '''Build a ``MySQLDBInfo`` from the ``SEAFILE_MYSQL_DB_*`` variables.

        Host, port, user and password are shared by all databases; only the
        database name is read from ``db_name_env`` (falling back to
        ``default_db_name``). The password is ``None`` when unset.
        '''
        host = os.environ.get('SEAFILE_MYSQL_DB_HOST', 'db')
        port = DBUpdater._parse_port(os.environ.get('SEAFILE_MYSQL_DB_PORT'))
        username = os.environ.get('SEAFILE_MYSQL_DB_USER', 'seafile')
        password = os.environ.get('SEAFILE_MYSQL_DB_PASSWORD')
        db = os.environ.get(db_name_env, default_db_name)
        return MySQLDBInfo(host, port, username, password, db)

    @staticmethod
    def get_ccnet_mysql_info(version):
        '''Return the connection info of the ccnet database.

        ``version`` is accepted for interface compatibility and not used.
        '''
        return DBUpdater._get_mysql_info('SEAFILE_MYSQL_DB_CCNET_DB_NAME',
                                         'ccnet_db')

    @staticmethod
    def get_seafile_mysql_info(version):
        '''Return the connection info of the seafile database.

        ``version`` is accepted for interface compatibility and not used.
        '''
        return DBUpdater._get_mysql_info('SEAFILE_MYSQL_DB_SEAFILE_DB_NAME',
                                         'seafile_db')

    @staticmethod
    def get_seahub_mysql_info():
        '''Return the connection info of the seahub database.'''
        return DBUpdater._get_mysql_info('SEAFILE_MYSQL_DB_SEAHUB_DB_NAME',
                                         'seahub_db')

    # Backend hooks: subclasses must override all of these.

    def update_ccnet_sql(self, ccnet_sql):
        raise NotImplementedError

    def update_seafile_sql(self, seafile_sql):
        raise NotImplementedError

    def update_seahub_sql(self, seahub_sql):
        raise NotImplementedError

    def update_seafevents_sql(self, seafevents_sql):
        raise NotImplementedError
class CcnetSQLiteDB(object):
    '''Map ccnet database names to their SQLite files under ``ccnet_dir``.'''

    def __init__(self, ccnet_dir):
        self.ccnet_dir = ccnet_dir

    def get_db(self, dbname):
        '''Return the path of the database file whose stem equals ``dbname``.

        The stem is the file name without directory and extension, e.g.
        ``groupmgr`` for ``GroupMgr/groupmgr.db``. Returns ``None`` when no
        known database matches.
        '''
        known_files = (
            'ccnet.db',
            'GroupMgr/groupmgr.db',
            'misc/config.db',
            'OrgMgr/orgmgr.db',
            'PeerMgr/usermgr.db',
        )

        # Lazily scan the known files and take the first whose stem matches.
        candidates = (
            relpath for relpath in known_files
            if os.path.splitext(os.path.basename(relpath))[0] == dbname
        )
        match = next(candidates, None)
        if match is None:
            return None
        return os.path.join(self.ccnet_dir, match)
class SQLiteDBUpdater(DBUpdater):
    '''Apply the upgrade SQL files to the SQLite database files.'''

    def __init__(self, version):
        '''Locate the SQLite database files from ``env_mgr``.'''
        DBUpdater.__init__(self, version, 'sqlite3')

        self.ccnet_db = CcnetSQLiteDB(env_mgr.ccnet_dir)
        self.seafile_db = os.path.join(env_mgr.seafile_dir, 'seafile.db')
        self.seahub_db = os.path.join(env_mgr.top_dir, 'seahub.db')
        self.seafevents_db = os.path.join(env_mgr.seafevents_db_dir, 'seafevents.db')

    def update_db(self):
        '''Run the common upgrade, then every ``ccnet/*.sql`` file.

        Each file in the ``ccnet`` sub-directory targets the ccnet database
        whose name equals the file's stem.
        '''
        super(SQLiteDBUpdater, self).update_db()
        for sql_path in glob.glob(os.path.join(self.sql_dir, 'ccnet', '*.sql')):
            self.update_ccnet_sql(sql_path)

    def apply_sqls(self, db_path, sql_path):
        '''Execute the ``;``-separated statements of ``sql_path`` on ``db_path``.

        All statements run in one transaction: it is committed when every
        statement succeeds and rolled back if one raises (the exception
        propagates).
        '''
        with open(sql_path, 'r') as fp:
            lines = fp.read().split(';')

        conn = sqlite3.connect(db_path)
        try:
            # "with conn" only commits or rolls back; it does NOT close the
            # connection, hence the explicit close() in the finally clause.
            with conn:
                for line in lines:
                    line = line.strip()
                    if line:
                        conn.execute(line)
        finally:
            conn.close()

    def update_ccnet_sql(self, sql_path):
        '''Apply ``sql_path`` to the ccnet database named after the file.'''
        dbname = os.path.splitext(os.path.basename(sql_path))[0]
        self.apply_sqls(self.ccnet_db.get_db(dbname), sql_path)

    def update_seafile_sql(self, sql_path):
        '''Apply ``sql_path`` to the seafile database.'''
        self.apply_sqls(self.seafile_db, sql_path)

    def update_seahub_sql(self, sql_path):
        '''Apply ``sql_path`` to the seahub database.'''
        self.apply_sqls(self.seahub_db, sql_path)

    def update_seafevents_sql(self, sql_path):
        '''Only report that seafevents is unsupported; nothing is executed.'''
        if self.is_pro:
            Utils.info('seafevents do not support sqlite3 database')
class MySQLDBUpdater(DBUpdater):
    '''Apply the upgrade SQL files to the MySQL databases.'''

    def __init__(self, version, ccnet_db_info, seafile_db_info, seahub_db_info):
        '''Store the connection info of each database.'''
        DBUpdater.__init__(self, version, 'mysql')
        self.ccnet_db_info = ccnet_db_info
        self.seafile_db_info = seafile_db_info
        self.seahub_db_info = seahub_db_info

    def update_ccnet_sql(self, ccnet_sql):
        '''Apply ``ccnet_sql`` to the ccnet database.'''
        self.apply_sqls(self.ccnet_db_info, ccnet_sql)

    def update_seafile_sql(self, seafile_sql):
        '''Apply ``seafile_sql`` to the seafile database.'''
        self.apply_sqls(self.seafile_db_info, seafile_sql)

    def update_seahub_sql(self, seahub_sql):
        '''Apply ``seahub_sql`` to the seahub database.'''
        self.apply_sqls(self.seahub_db_info, seahub_sql)

    def update_seafevents_sql(self, seafevents_sql):
        '''Apply ``seafevents_sql`` on pro installations only.'''
        if self.is_pro:
            Utils.info('updating seafevents database...')
            # NOTE(review): the seafevents SQL is run against the seahub
            # database connection; presumably intentional - confirm.
            self.apply_sqls(self.seahub_db_info, seafevents_sql)

    def get_conn(self, info):
        '''Open a pymysql connection described by ``info``.

        A unix socket takes precedence over host/port when it is set. Any
        connection failure aborts the script through ``Utils.error``.
        '''
        kw = dict(
            user=info.username,
            passwd=info.password,
            db=info.db,
        )
        if info.unix_socket:
            kw['unix_socket'] = info.unix_socket
        else:
            kw['host'] = info.host
            kw['port'] = info.port

        try:
            conn = pymysql.connect(**kw)
        except Exception as e:
            # Use the server message (args[1]) of an OperationalError when it
            # is present; guard the index so a short args tuple cannot raise
            # IndexError while reporting the real failure.
            if isinstance(e, pymysql.OperationalError) and len(e.args) > 1:
                msg = str(e.args[1])
            else:
                msg = str(e)
            Utils.error('Failed to connect to mysql database %s: %s' % (info.db, msg))

        return conn

    def execute_sql(self, conn, sql):
        '''Execute and commit one statement; failures only produce a warning.

        Errors are deliberately not re-raised so the remaining statements of
        an upgrade file are still attempted.
        '''
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            conn.commit()
        except Exception as e:
            msg = str(e)
            Utils.warning('Failed to execute sql: %s' % msg)
        finally:
            # Release the cursor whether or not the statement succeeded.
            cursor.close()

    def apply_sqls(self, info, sql_path):
        '''Execute the ``;``-separated statements of ``sql_path`` one by one.

        A connection is opened for the file and closed again once every
        statement has been attempted.
        '''
        with open(sql_path, 'r') as fp:
            lines = fp.read().split(';')

        conn = self.get_conn(info)
        try:
            for line in lines:
                line = line.strip()
                if line:
                    self.execute_sql(conn, line)
        finally:
            # Previously the connection was never closed.
            conn.close()
def main():
    '''Entry point: upgrade the databases to the version given in argv[1].

    The upgrade is skipped (normal exit) when ``SEAFILE_SKIP_DB_UPGRADE`` is
    ``1``, ``true`` or ``on`` (case-insensitive). Returns 0 after a completed
    upgrade.
    '''
    skipdb = os.environ.get('SEAFILE_SKIP_DB_UPGRADE', '').lower()
    if skipdb in ('1', 'true', 'on'):
        print('Database upgrade skipped because SEAFILE_SKIP_DB_UPGRADE=%s' % skipdb)
        sys.exit()

    # Fail with a readable message instead of a bare IndexError traceback
    # when the version argument is missing.
    if len(sys.argv) < 2:
        Utils.error('Missing version argument, stop db upgrade')

    version = sys.argv[1]
    db_updater = DBUpdater.get_instance(version)
    db_updater.update_db()

    return 0


if __name__ == '__main__':
    main()