diff --git a/scripts/upgrade/db_update_helper.py b/scripts/upgrade/db_update_helper.py index d9247bc..45b4b46 100644 --- a/scripts/upgrade/db_update_helper.py +++ b/scripts/upgrade/db_update_helper.py @@ -26,8 +26,10 @@ class EnvManager(object): self.seafile_dir = os.environ['SEAFILE_CONF_DIR'] self.central_config_dir = os.environ.get('SEAFILE_CENTRAL_CONF_DIR') + env_mgr = EnvManager() + class Utils(object): @staticmethod def highlight(content, is_error=False): @@ -69,6 +71,8 @@ class MySQLDBInfo(object): class DBUpdater(object): def __init__(self, version, name): self.sql_dir = os.path.join(env_mgr.upgrade_dir, 'sql', version, name) + pro_path = os.path.join(env_mgr.install_path, 'pro') + self.is_pro = os.path.exists(pro_path) @staticmethod def get_instance(version): @@ -263,9 +267,6 @@ class SQLiteDBUpdater(DBUpdater): self.seahub_db = os.path.join(env_mgr.top_dir, 'seahub.db') self.seafevents_db = os.path.join(env_mgr.top_dir, 'seafevents.db') - def is_pro(self): - return os.path.exists(self.seafevents_db) - def update_db(self): super(SQLiteDBUpdater, self).update_db() for sql_path in glob.glob(os.path.join(self.sql_dir, 'ccnet', '*.sql')): @@ -294,9 +295,8 @@ class SQLiteDBUpdater(DBUpdater): self.apply_sqls(self.seahub_db, sql_path) def update_seafevents_sql(self, sql_path): - if self.is_pro(): - Utils.info('updating seafevents database...') - self.apply_sqls(self.seafevents_db, sql_path) + if self.is_pro: + Utils.info('seafevents do not support sqlite3 database') class MySQLDBUpdater(DBUpdater): @@ -316,9 +316,9 @@ class MySQLDBUpdater(DBUpdater): self.apply_sqls(self.seahub_db_info, seahub_sql) def update_seafevents_sql(self, seafevents_sql): - if self.is_pro(self.seahub_db_info): + if self.is_pro: Utils.info('updating seafevents database...') - self.apply_sqls(self.seahub_db_info, seahub_sql) + self.apply_sqls(self.seahub_db_info, seafevents_sql) def get_conn(self, info): kw = dict( @@ -368,14 +368,6 @@ class MySQLDBUpdater(DBUpdater): else: self.execute_sql(conn, line) - def is_pro(self, info): - conn = self.get_conn(info) - cursor = conn.cursor() - text = "select count(1) from information_schema.tables where table_schema=%s and table_name=%s" - cursor.execute(text, (info.db, 'Event')) - res = cursor.fetchone() - return res[0] == 1 - def main(): skipdb = os.environ.get('SEAFILE_SKIP_DB_UPGRADE', '').lower()