From acae68eb0492604783088d01538819d5de117566 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?=
Date: Fri, 1 Apr 2022 16:49:21 +0800
Subject: [PATCH] [model checkpoint] updated checkpoint save/load utils (#592)

---
 colossalai/utils/__init__.py      |   9 +-
 colossalai/utils/checkpointing.py | 385 +++++++++++++++++-------------
 2 files changed, 218 insertions(+), 176 deletions(-)

diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 3881d1eb7..f470671fc 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,9 +1,9 @@
 from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
 from .activation_checkpoint import checkpoint
-
+from .checkpointing import load_checkpoint, save_checkpoint
 from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
-                     free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
-                     is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
+                     ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
+                     is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
                      sync_model_param)
 from .data_sampler import DataParallelSampler, get_dataloader
@@ -18,5 +18,6 @@ __all__ = [
     'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
-    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector'
+    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
+    'ensure_path_exists'
 ]
diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py
index e822a8bfd..bc49656bc 100644
--- a/colossalai/utils/checkpointing.py
+++ b/colossalai/utils/checkpointing.py
@@ -1,212 +1,253 @@
-import os
-import os.path as osp
-import re
-from typing import Tuple
-from pathlib import Path
+from collections import OrderedDict
+from itertools import chain
 
 import torch
-
-from colossalai.context import Config
+import torch.distributed as dist
+from colossalai.communication.collective import scatter_object_list
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
+from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 
-__all__ = [
-    'get_checkpoint_path', 'get_latest_checkpoint_path', 'get_latest_checkpoint_pattern', 'save_checkpoint',
-    'load_checkpoint'
-]
+from .common import is_using_pp
+
+__all__ = ["save_checkpoint", "load_checkpoint"]
 
 
-def unwrap_config(config: Config):
-    """Unwrap Config objects to normal dicts
-    """
-    config_dict = dict()
-    for k, v in config.items():
-        if isinstance(v, dict):
-            config_dict[k] = unwrap_config(v)
-        else:
-            config_dict[k] = v
-
-    return config_dict
+def broadcast_state_dict(state_dict, parallel_mode):
+    state_dict = [state_dict.copy() if isinstance(state_dict, dict) else state_dict]
+    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
+    dist.broadcast_object_list(state_dict, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
+    return state_dict[0]
 
 
-def _get_ranks_name():
-    # tensor parallel
-    tp_local_rank = 0
-    if gpc.is_initialized(ParallelMode.TENSOR):
-        tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+def partition_tensor_parallel_state_dict(
+    state_dict: OrderedDict, parallel_mode: ParallelMode, dims: dict = dict(), partition_states: dict = dict()
+):
+    src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
+    depth = gpc.get_world_size(parallel_mode)
 
-    # pipeline parallel
-    pp_local_rank = 0
-    if gpc.is_initialized(ParallelMode.PIPELINE):
-        pp_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    if gpc.get_local_rank(parallel_mode) == 0:
 
-    ranks_name = f'tp{tp_local_rank}-pp{pp_local_rank}'
-    return ranks_name
+        partitioned_state_list = [dict() for _ in range(depth)]
+        for key in list(state_dict.keys()):
+            param = state_dict.pop(key)
+            dim = dims.get(key, 0)
+            do_partition = partition_states.get(key, True)
+            if do_partition:
+                param = torch.chunk(param, depth, dim=dim)
+            for i, p in enumerate(partitioned_state_list):
+                p[key] = param[i] if do_partition else param
 
-
-def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
-    ranks_name = _get_ranks_name()
-    return f'epoch{epoch}-{ranks_name}{suffix}.pt'
-
-
-def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
-    """This is a function to generate the checkpoint path from the tuple
-    (checkpoint_dir, epoch, suffix, gpu_parallel_rank).
-    This is useful during generation and recuperation of the checkpoint.
-
-    Args:
-        checkpoint_dir (str): Set up a directory for saving checkpoints.
-        epoch (int): Epoch number (indicate how many epochs have you trained this model).
-        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
-
-    Returns:
-        str: The checkpoint path to be generated.
-    """
-    ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
-    return os.path.join(checkpoint_dir, ckpt_filename)
-
-
-def _ensure_directory_exists(filename: str):
-    # ensure the directory exists
-    dirpath = os.path.dirname(filename)
-    if not os.path.exists(dirpath):
-        Path(dirpath).mkdir(parents=True, exist_ok=True)
-
-
-def get_latest_checkpoint_pattern(suffix: str = ''):
-    """Generate Regular expression of the latest checkpoint's pattern.
-
-    Args:
-        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''.
-
-    Returns:
-        str: The regular expression of checkpoint pattern.
-    """
-    ranks_name = _get_ranks_name()
-    pattern = r'epoch(\d+)-{}{}\.pt'.format(ranks_name, suffix)
-    ckpt_pattern = re.compile(pattern)
-    return ckpt_pattern
-
-
-def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
-    """This is a function to retrieve the latest checkpoint path from the tuple
-    (checkpoint_dir, suffix, gpu_parallel_rank).
-    This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
-
-    Args:
-        checkpoint_dir (str): Directory for saving checkpoints
-        suffix (str, optional): Additional notation to specify the model or checkpoint, defaults to ''
-
-    Returns:
-        str: The latest retrieved checkpoint path.
-
-    Raises:
-        FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given.
-    """
-    CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
-
-    last_epoch = -1
-    assert osp.isdir(checkpoint_dir), f'{checkpoint_dir} is not a directory'
-
-    for filename in os.listdir(checkpoint_dir):
-        ret = CKPT_NAME_PAT.match(filename)
-        if ret:
-            epoch = int(ret[0].split('-')[0].lstrip('epoch'))
-            if epoch > last_epoch:
-                last_epoch = epoch
-
-    if last_epoch == -1:
-        ranks_name = _get_ranks_name()
-        raise FileNotFoundError(f"Cannot find the latest checkpoint file for {ranks_name} in {checkpoint_dir}")
     else:
-        target_file = _get_standard_checkpoint_filename(last_epoch, suffix=suffix)
-        path = osp.join(checkpoint_dir, target_file)
-        return path
+        partitioned_state_list = [None for _ in range(depth)]
+
+    partitioned_state = [None]
+    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
+    return partitioned_state[0]
 
 
-def save_checkpoint(checkpoint_path: str,
-                    epoch: int,
-                    model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    **kwargs):
-    """Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as
-    model, optimizer, lr_scheduler etc. into a checkpoint dictionary.
+def gather_tensor_parallel_state_dict(
+    state_dict: OrderedDict,
+    parallel_mode: ParallelMode,
+    dims: dict = dict(),
+    partition_states: dict = dict(),
+    keep_vars: bool = False,
+):
+    dst_rank = gpc.get_ranks_in_group(parallel_mode)[0]
+    depth = gpc.get_world_size(parallel_mode)
 
-    This method can be used for both :class:`colossalai.nn.BaseModel` and normal :class:`torch.nn.Module`.
+    for key in list(state_dict.keys()):
+        param = state_dict.pop(key)
+        param = param if keep_vars else param.detach()
+        dim = dims.get(key, 0)
+        do_partition = partition_states.get(key, True)
+        if do_partition:
+            temp = param.transpose(0, dim).contiguous()
+            gather_list = None
+            if gpc.get_local_rank(parallel_mode) == 0:
+                shape = list(param.shape)
+                shape[0], shape[dim] = shape[dim], shape[0]
+                shape[0] *= depth
+                param = torch.empty(shape, dtype=param.dtype, device=param.device)
+                gather_list = list(torch.chunk(param, depth, dim=0))
+            dist.gather(temp, gather_list, dst=dst_rank, group=gpc.get_cpu_group(parallel_mode))
+            param = torch.transpose(param, 0, dim)
+        # update params in state_dict only on local rank 0
+        if gpc.get_local_rank(parallel_mode) == 0:
+            state_dict[key] = param
+
+    return state_dict
+
+
+def _send_state_dict(state_dict, dst, parallel_mode):
+    state_tensor, state_size = dist.distributed_c10d._object_to_tensor(state_dict)
+    dist.send(state_size, dst, group=gpc.get_cpu_group(parallel_mode))
+    dist.send(state_tensor, dst, group=gpc.get_cpu_group(parallel_mode))
+
+
+def _recv_state_dict(src, parallel_mode):
+    state_size = torch.tensor([0], dtype=torch.long)
+    dist.recv(state_size, src, group=gpc.get_cpu_group(parallel_mode))
+    state_tensor = torch.empty(state_size.item(), dtype=torch.uint8)
+    dist.recv(state_tensor, src, group=gpc.get_cpu_group(parallel_mode))
+    state_dict = dist.distributed_c10d._tensor_to_object(state_tensor, state_size)
+    return state_dict
+
+
+def partition_pipeline_parallel_state_dict(model, state_dict):
+    pipeline_state = OrderedDict()
+
+    if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+        # receive all states from prev stage
+        if not gpc.is_first_rank(ParallelMode.PIPELINE):
+            state_dict = _recv_state_dict(gpc.get_prev_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
+        # move states to output
+        for name, _ in model.named_parameters(recurse=True):
+            if name in state_dict:
+                pipeline_state[name] = state_dict.pop(name)
+        for name, _ in model.named_buffers(recurse=True):
+            if name in state_dict:
+                pipeline_state[name] = state_dict.pop(name)
+        for name, _ in model.named_modules():
+            extra_state_key = name + "." + _EXTRA_STATE_KEY_SUFFIX
+            if extra_state_key in state_dict:
+                pipeline_state[extra_state_key] = state_dict.pop(extra_state_key)
+        # send rest states to next stage
+        if not gpc.is_last_rank(ParallelMode.PIPELINE):
+            _send_state_dict(state_dict, gpc.get_next_global_rank(ParallelMode.PIPELINE), ParallelMode.PIPELINE)
+
+    return pipeline_state
+
+
+def gather_pipeline_parallel_state_dict(state_dict):
+    gathered_states = (
+        [None for _ in range(gpc.get_world_size(ParallelMode.PIPELINE))]
+        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        else None
+    )
+    dist.gather_object(
+        state_dict,
+        gathered_states,
+        dst=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[0],
+        group=gpc.get_cpu_group(ParallelMode.PIPELINE),
+    )
+
+    state_dict = (
+        OrderedDict(chain.from_iterable(state.items() for state in gathered_states))
+        if gpc.get_local_rank(ParallelMode.PIPELINE) == 0
+        else OrderedDict()
+    )
+
+    return state_dict
+
+
+def save_checkpoint(
+    file,
+    epoch: int,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer = None,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    **kwargs
+):
+    """Stores the checkpoint to disk. Saves all the training components' parameters or buffers, such as model,
+    optimizer, lr_scheduler etc. into a checkpoint dictionary.
 
     Args:
-        checkpoint_path (str): Set up a directory for saving checkpoints.
-        epoch (int): Epoch number (indicate how many epochs have you trained this model).
-        model (:class:`torch.nn.Module`): Model to be registered.
-        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be registered.
+        file: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a
+            file name.
+        epoch (int): Epoch number (indicates how many epochs have you trained this model).
+        model (:class:`torch.nn.Module`): Model to be saved.
+        optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to be saved.
         lr_scheduler (Union[:class:`torch.optim.lr_scheduler`,
-            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be registered, defaults to None.
-        kwargs (dict): additional parameters to be saved.
+            :class:`colossalai.nn.lr_scheduler`], optional): lr_scheduler to be saved, defaults to None.
+        pickle_module: module used for pickling metadata and objects
+        pickle_protocol: can be specified to override the default protocol
     """
-    # for compatibility with normal pytorch nn.Module
-    if hasattr(model, 'state_dict_for_save_checkpoint'):
-        model_sd = model.state_dict_for_save_checkpoint()
-    else:
-        model_sd = model.state_dict()
-
-    # ckpt container
-    checkpoint = {'epoch': epoch, 'model': model_sd, 'optimizer': optimizer.state_dict(), **kwargs}
-    if lr_scheduler is not None:
-        checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
+    checkpoint = {"epoch": epoch}
 
-    _ensure_directory_exists(checkpoint_path)
-    torch.save(checkpoint, checkpoint_path)
+    model_state = model.state_dict()
+    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+        model_state = gather_pipeline_parallel_state_dict(model_state)
+
+    if gpc.get_global_rank() == 0:
+        checkpoint["model"] = model_state
+
+        # if optimizer is not None:
+        #     checkpoint['optimizer'] = optimizer.state_dict()
+
+        # if lr_scheduler is not None:
+        #     checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
+
+        torch.save(checkpoint, file, **kwargs)
 
 
-def load_checkpoint(checkpoint_path: str,
-                    model: torch.nn.Module,
-                    optimizer: torch.optim.Optimizer,
-                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    finetune: bool = False,
-                    strict: bool = True) -> Tuple:
-    """Loads the checkpoint file.
+def load_checkpoint(
+    file,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer = None,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
+    strict: bool = True,
+):
+    """Loads training states from a checkpoint file.
 
-    If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
-    So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler)
-    and its descendants.
-
-    If finetune is True, then only the weights and buffers of model should be reloaded.
-    If strict is True, then the keys of state_dict must exactly match the keys returned
-    by this module’s state_dict() function.
-
-    Args:
-        checkpoint_path (str): The exact and matched checkpoint_path directory to retrieve appropriate state_dict.
-        model (:class:`torch.nn.Module`): Model to reload parameters and buffers.
+    Args:
+        file: a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike
+            object containing a file name.
+        model (:class:`torch.nn.Module`): Model to load saved weights and buffers.
         optimizer (Union[:class:`torch.optim.Optimizer`, :class:`colossalai.nn.optimizer`]): Optimizer to recuperate.
         lr_scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`, optional):
             lr_scheduler to recuperate, defaults to None.
-        finetune (bool, optional): Whether to finetune the model with new dataset or
-            continue the pre-training, defaults to False.
         strict (bool, optional): Whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of parameters and buffers in model, defaults to True.
 
     Returns:
-        Tuple(int, ``checkpoint``): The tuple (the epoch number of the checkpoint retrieved, the checkpoint retrieved).
+        int: The saved epoch number.
 
     Raises:
-        ValueError: Raise error if the model/optimizer cannot successfully be recuperated
+        RuntimeError: Raise error if the model/optimizer cannot successfully be recuperated
     """
-    # Load the checkpoint.
-    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    state_dict = (
+        torch.load(file, map_location=torch.device("cpu")) if gpc.get_local_rank(ParallelMode.MODEL) == 0 else None
+    )
+
+    # model states
+    model_state = state_dict.pop("model") if state_dict is not None else dict()
+    # pipeline
+    if is_using_pp():
+        model_state = partition_pipeline_parallel_state_dict(model, model_state)
 
     try:
-        last_epoch = checkpoint.pop('epoch') if not finetune else 0
-        model.load_state_dict(checkpoint.pop('model'), strict=strict)
-    except KeyError:
-        raise ValueError('Checkpoint is corrupted')
+        model.load_state_dict(model_state, strict=strict)
+    except RuntimeError as e:
+        error_msgs = str(e)
+        if error_msgs.startswith("Error(s) in loading state_dict for "):
+            error_msgs = error_msgs.split("\n\t")[1:]
+            dst_rank = gpc.get_ranks_in_group(ParallelMode.MODEL)[0]
+            all_error_msgs = [None for _ in range(gpc.get_world_size(ParallelMode.MODEL))]
+            dist.gather_object(error_msgs, all_error_msgs, dst=dst_rank, group=gpc.get_cpu_group(ParallelMode.MODEL))
+            if gpc.get_global_rank() == 0:
+                all_error_msgs = list(chain.from_iterable(all_error_msgs))
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        model.__class__.__name__, "\n\t".join(all_error_msgs)
+                    )
+                )
+        else:
+            raise e
 
-    if not finetune:
-        try:
-            optimizer.load_state_dict(checkpoint.pop('optimizer'))
-        except KeyError:
-            raise ValueError('Checkpoint is corrupted')
+    # broadcast the rest states
+    state_dict = broadcast_state_dict(state_dict, ParallelMode.MODEL)
 
-    if lr_scheduler is not None and 'lr_scheduler' in checkpoint:
-        lr_scheduler.load_state_dict(checkpoint.pop('lr_scheduler'))
+    # # optimizer states
+    # if optimizer is not None and 'optimizer' in state_dict:
+    #     optimizer.load_state_dict(state_dict['optimizer'])
 
-    return last_epoch, checkpoint
+    # # lr scheduler states
+    # if lr_scheduler is not None and 'lr_scheduler' in state_dict:
+    #     lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
+
+    # last epoch
+    last_epoch = state_dict.pop("epoch", -1)
+
+    return last_epoch
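
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the updated utilities might be called from a ColossalAI training script. It assumes the parallel context has already been initialized on every rank (for example via colossalai.launch_from_torch); the model, learning rate, file name, and epoch number are placeholders.

    import torch
    from colossalai.utils import save_checkpoint, load_checkpoint

    # Assumption: colossalai.launch_from_torch(...) (or an equivalent launcher)
    # has already set up the global parallel context on every rank.
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Every rank calls save_checkpoint; when pipeline parallelism is in use the
    # per-stage states are gathered, and the full "model" state dict is placed in
    # the checkpoint only on global rank 0. In this version the optimizer and
    # lr_scheduler branches are commented out, so only the epoch number and the
    # model weights are persisted.
    save_checkpoint("ckpt_epoch10.pt", epoch=10, model=model, optimizer=optimizer)

    # On load, rank 0 of the model-parallel group reads the file, pipeline stages
    # take their own slice of the state dict, the remaining entries are broadcast,
    # and the saved epoch number is returned.
    last_epoch = load_checkpoint("ckpt_epoch10.pt", model=model, optimizer=optimizer, strict=True)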