diff --git a/scripts/migrate-repo.py b/scripts/migrate-repo.py
index a4fd7328df..e1a4c9fdc5 100644
--- a/scripts/migrate-repo.py
+++ b/scripts/migrate-repo.py
@@ -4,6 +4,7 @@ import os
 import sys
 import logging
 import configparser
+import json
 from sqlalchemy import create_engine, text
 from sqlalchemy.orm import sessionmaker
 from migrate import ObjMigrateWorker
@@ -29,7 +30,7 @@ def main(argv):
     else:
         migrate_repo(repo_id, orig_storage_id, dest_storage_id)
 
-def parse_seafile_config():
+def parse_seafile_config(storage_id):
     env = os.environ
     seafile_conf = os.path.join(env['SEAFILE_CENTRAL_CONF_DIR'], 'seafile.conf')
     cp = configparser.ConfigParser()
@@ -39,13 +40,33 @@
     user = cp.get('database', 'user')
     passwd = cp.get('database', 'password')
     db_name = cp.get('database', 'db_name')
-    return host, port, user, passwd, db_name
 
-def get_repo_ids():
-    host, port, user, passwd, db_name = parse_seafile_config()
-    url = 'mysql+pymysql://' + user + ':' + passwd + '@' + host + ':' + port + '/' + db_name
-    print(url)
-    sql = 'SELECT repo_id FROM Repo'
+    is_default = is_default_storage(cp, storage_id)
+
+    return host, port, user, passwd, db_name, is_default
+
+def is_default_storage(cp, orig_storage_id):
+    json_file = cp.get('storage', 'storage_classes_file')
+    with open(json_file) as f:
+        json_cfg = json.load(f)
+
+    is_default = False
+
+    for bend in json_cfg:
+        storage_id = bend['storage_id']
+        if storage_id == orig_storage_id:
+            if 'is_default' in bend:
+                is_default = bend['is_default']
+            break
+
+    return is_default
+
+def get_repo_ids_by_storage_id(url, storage_id=None):
+    if storage_id:
+        sql = 'SELECT repo_id FROM RepoStorageId WHERE storage_id="%s"' % storage_id
+    else:
+        sql = 'SELECT repo_id FROM RepoStorageId'
+
     try:
         engine = create_engine(url, echo=False)
         session = sessionmaker(engine)()
@@ -53,8 +74,52 @@
     except:
         return None
     else:
-        result = result_proxy.fetchall()
-        return result
+        results = result_proxy.fetchall()
+
+        repo_ids = {}
+        for r in results:
+            try:
+                repo_id = r[0]
+            except:
+                continue
+            repo_ids[repo_id] = repo_id
+        return repo_ids
+
+def get_repo_ids(storage_id):
+    host, port, user, passwd, db_name, is_default = parse_seafile_config(storage_id)
+    url = 'mysql+pymysql://' + user + ':' + passwd + '@' + host + ':' + port + '/' + db_name
+
+    if is_default:
+        all_repo_ids = get_repo_ids_by_storage_id(url)
+    storage_repo_ids = get_repo_ids_by_storage_id(url, storage_id)
+
+    sql = 'SELECT repo_id FROM Repo'
+
+    try:
+        engine = create_engine(url, echo=False)
+        session = sessionmaker(engine)()
+        result_proxy = session.execute(text(sql))
+    except:
+        return None
+    else:
+        results = result_proxy.fetchall()
+
+        ret_repo_ids = []
+        for r in results:
+            try:
+                repo_id = r[0]
+            except:
+                continue
+            # If it's the default storage, we should also return repos that are not in the RepoStorageId table.
+            # The Repo table is checked to prevent returning deleted repos.
+            if is_default:
+                if repo_id in storage_repo_ids or repo_id not in all_repo_ids:
+                    ret_repo_ids.append(repo_id)
+            else:
+                if repo_id in storage_repo_ids:
+                    ret_repo_ids.append(repo_id)
+
+        return ret_repo_ids
 
 def migrate_repo(repo_id, orig_storage_id, dest_storage_id):
     api.set_repo_status (repo_id, REPO_STATUS_READ_ONLY)
@@ -100,13 +165,9 @@ logging.info('The process of migrating repo [%s] is over.\n', repo_id)
 
 def migrate_repos(orig_storage_id, dest_storage_id):
-    repo_ids = get_repo_ids()
+    repo_ids = get_repo_ids(orig_storage_id)
 
     for repo_id in repo_ids:
-        try:
-            repo_id = repo_id[0]
-        except:
-            continue
         api.set_repo_status (repo_id, REPO_STATUS_READ_ONLY)
 
         dtypes = ['commits', 'fs', 'blocks']
         workers = []
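
Note on the storage-classes lookup: is_default_storage() walks the JSON file named by the storage_classes_file option in seafile.conf and reports the is_default flag of the backend entry whose storage_id matches. A minimal, self-contained sketch of that lookup follows; the hot_storage/cold_storage ids and the inline sample config are illustrative assumptions, not taken from this patch.

import json

# Hypothetical storage_classes_file content: each backend entry carries a
# 'storage_id' and, optionally, an 'is_default' flag, which is all the
# lookup below relies on.
SAMPLE_STORAGE_CLASSES = '''
[
    {"storage_id": "hot_storage",  "name": "Hot Storage",  "is_default": true},
    {"storage_id": "cold_storage", "name": "Cold Storage"}
]
'''

def is_default_storage_sketch(json_cfg, orig_storage_id):
    # Scan the backend list for a matching storage_id and report its
    # 'is_default' flag, treating a missing flag as False.
    for bend in json_cfg:
        if bend['storage_id'] == orig_storage_id:
            return bend.get('is_default', False)
    return False

if __name__ == '__main__':
    cfg = json.loads(SAMPLE_STORAGE_CLASSES)
    print(is_default_storage_sketch(cfg, 'hot_storage'))    # True
    print(is_default_storage_sketch(cfg, 'cold_storage'))   # False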
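The selection rule in get_repo_ids() amounts to set arithmetic over three queries: for the default storage class, a repo qualifies if it is explicitly mapped to that class in RepoStorageId, or if it has no RepoStorageId row at all (presumably repos created before storage classes were enabled, which live on the default backend); for a non-default class, only explicitly mapped repos qualify. In both cases the Repo table acts as the universe, which keeps deleted repos out of the result. A standalone sketch of that rule, with made-up repo ids standing in for the query results:

# Hypothetical stand-ins for the three queries the patch issues.
all_mapped  = {'r1', 'r2', 'r3'}   # every repo_id in RepoStorageId
mapped_here = {'r1'}               # repo_ids mapped to the requested storage class
live_repos  = {'r1', 'r2', 'r4'}   # repo_ids in Repo (excludes deleted repos)

def select_repos(is_default):
    # Mirrors the filtering loop in get_repo_ids().
    if is_default:
        # Mapped to this class, or not mapped anywhere.
        return {r for r in live_repos if r in mapped_here or r not in all_mapped}
    return {r for r in live_repos if r in mapped_here}

print(select_repos(True))    # {'r1', 'r4'}: r4 has no mapping, so it sits on the default backend
print(select_repos(False))   # {'r1'}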