diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 0a3221b23..616b218b2 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,5 +1,8 @@
+import logging
+import os
 import warnings
 from functools import partial
+from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -10,10 +13,16 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
+from colossalai.checkpoint_io.utils import (
+    get_optimizer_base_filenames,
+    get_shard_filename,
+    save_param_groups,
+    save_state_dict,
+)
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -32,21 +41,104 @@ SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
-        """
-        Save optimizer to checkpoint but only on master process.
-        """
-        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
-        warnings.warn(
-            'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+        """Save optimizer to checkpoint but only on the master process.
 
-    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
-        warnings.warn(
-            'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
-        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
-        super().load_optimizer(optimizer, checkpoint)
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save state_dict
+            checkpoint (str): Path to save checkpoint
+            gather_dtensor (bool): Whether to gather dtensor, not used
+        """
+
+        # `state_dict` of LowLevelZeroOptimizer performs collective communication,
+        # so it must be called on every rank; if only the master rank collected the
+        # state_dict and saved it, the communication on each rank would not match
+        state_dict = optimizer.state_dict()
+        if self.coordinator.is_master():
+            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def save_sharded_optimizer(self,
+                               optimizer: OptimizerWrapper,
+                               checkpoint: str,
+                               gather_dtensor: bool = False,
+                               prefix: str = None,
+                               size_per_shard: int = 1024):
+        """
+        Save a sharded ZeRO optimizer checkpoint under the given checkpointing path.
+        The following files will be created under the path:
+        - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+        - A group file (pytorch_optim_group.bin) recording information of param_groups
+        - Multiple files (pytorch_optim-000XX.bin) that store the optimizer state tensors in shards
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+            checkpoint (str): Path to save optimizer state_dict
+            gather_dtensor (bool): Whether to gather dtensor, not used
+            prefix (str): Prefix of the files to save
+            size_per_shard (int): Max size (in MB) of each file that stores state tensors
+        """
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        # this state_dict is only used for its 'param_groups'
+        state_dict = optimizer.optim.state_dict()
+        # state sharding is handled by the low-level zero optimizer
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+
+        # Prepare file paths and the index file.
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        index_file = CheckpointIndexFile(checkpoint)
+
+        # Store the information of param groups to param_group_file.
+        index_file.append_meta_data("param_groups", param_group_file)
+        group_file_path = os.path.join(checkpoint, param_group_file)
+        save_param_groups(state_dict, group_file_path)
+
+        # Save shards of optimizer states.
+        total_size = 0
+        for idx, shard_pair in enumerate(sharded_state):
+            shard, current_size = shard_pair
+            shard_file = get_shard_filename(states_name, idx)
+            total_size = total_size + current_size
+            for param_id in shard.keys():
+                index_file.append_weight_map(str(param_id), shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint, shard_file)
+            if self.coordinator.is_master():
+                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+
+        # Wrap up the index file.
+        index_file.append_meta_data("total_size", total_size)
+        if self.coordinator.is_master():
+            index_file.write_index_file(save_index_file)
+        logging.info(f"The optimizer is going to be split into checkpoint shards. "
+                     f"You can find where each parameter has been saved in the "
+                     f"index located at {save_index_file}.")
+
+    def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
+        """Load a sharded optimizer with the given path to the index file.
+
+        Args:
+            optimizer (OptimizerWrapper): Optimizer to load state_dict
+            index_file_path (str): Path to the index file
+            prefix (str): Not used.
+        """
+        super().load_sharded_optimizer(optimizer, index_file_path, prefix)
+        # the loaded states are unsharded; pad each flattened tensor so that it divides
+        # evenly by world_size and keep only this rank's partition
+        current_rank_state_dict = optimizer.optim.state_dict()['state']
+        for param_idx, state in current_rank_state_dict.items():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    padding_size = (self.coordinator.world_size -
+                                    v.numel() % self.coordinator.world_size) % self.coordinator.world_size
+                    with torch.no_grad():
+                        v = v.flatten()
+                        if padding_size > 0:
+                            v = torch.nn.functional.pad(v, [0, padding_size])
+                        v_list = v.split(v.numel() // self.coordinator.world_size)
+                        current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
 
 
 class LowLevelZeroModel(ModelWrapper):
@@ -74,36 +166,6 @@ class LowLevelZeroModel(ModelWrapper):
         return super().forward(*args, **kwargs)
 
 
-class LowLevelZeroOptimizer(OptimizerWrapper):
-
-    def __init__(self,
-                 module: nn.Module,
-                 optimizer: Optimizer,
-                 zero_optim_config: dict,
-                 optim_kwargs: dict,
-                 verbose: bool = False) -> None:
-        optimizer = zero_optim_wrapper(module,
-                                       optimizer,
-                                       optim_config=zero_optim_config,
-                                       **optim_kwargs,
-                                       verbose=verbose)
-        super().__init__(optimizer)
-
-    def backward(self, loss: Tensor, *args, **kwargs):
-        self.optim.backward(loss)
-
-    def clip_grad_by_norm(self,
-                          max_norm: Union[float, int],
-                          norm_type: Union[float, int] = 2,
-                          error_if_nonfinite: bool = False,
-                          *args,
-                          **kwargs) -> Tensor:
-        warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
-
-    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
-        raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
-
-
 class LowLevelZeroPlugin(DPPluginBase):
     """
     Plugin for low level zero.
@@ -211,8 +273,11 @@ class LowLevelZeroPlugin(DPPluginBase):
 
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
-                                              self.verbose)
+            optimizer = zero_optim_wrapper(model.unwrap(),
+                                           optimizer,
+                                           optim_config=self.zero_optim_config,
+                                           **self.optim_kwargs,
+                                           verbose=self.verbose)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
 
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 72bec8b0c..023db122f 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -2,7 +2,7 @@
 import copy
 from contextlib import contextmanager
 from functools import partial
-from typing import Optional
+from typing import Dict, Iterator, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -447,18 +447,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     # Gradient Synchronization #
     ############################
 
+    # this method is used to synchronize gradients manually
+    def sync_grad(self):
+        for group_id in range(self.num_param_groups):
+            param_group = self._working_param_groups[group_id]
+            for param in param_group:
+                if param.requires_grad and param.grad is not None:
+                    self._add_to_bucket(param, group_id)
+
+        self._run_reduction()
+
     def _reduce_grad(self, partition_grad):
         # if not overlapping communication (no reduction hook is attached) when zero1
         # we need to manually reduce these gradients
         if not partition_grad and not self._overlap_communication:
-            for group_id in range(len(self._working_param_groups)):
-                param_group = self._working_param_groups[group_id]
-                for param in param_group:
-                    if param.grad is not None:
-                        self._add_to_bucket(param, group_id)
-
-            # run reduction
-            self._run_reduction()
+            self.sync_grad()
+        else:
+            self._run_reduction()
 
     # this context comes from pytorch DDP
     @contextmanager
@@ -473,7 +478,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     ##############
     # State Dict #
     ##############
-    def _pack_state(self, state: dict) -> dict:
+
+    def _pack_state(self, state: Dict) -> Dict:
         # comes from pytorch optimizer.state_dict()
         param_mappings = {}
         start_index = 0
@@ -487,17 +493,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             start_index += len(packed['params'])
             return packed
 
-        param_groups = [pack_group(g) for g in self.param_groups]
+        param_groups = [pack_group(g) for g in self.optim.param_groups]
         # Remap state to use order indices as keys
         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
 
         return {'state': packed_state, 'param_groups': param_groups}
 
-    def state_dict(self) -> dict:
+    def state_dict(self) -> Dict:
         """Return a state_dict same with DDP
 
         Returns:
-            dict: the pytorch form state_dict
+            Dict: the pytorch form state_dict
         """
         zero_state = dict()
         for param, state in self.optim.state.items():
@@ -514,7 +520,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         return states_dict
 
-    def load_state_dict(self, state_dict: dict):
+    def load_state_dict(self, state_dict: Dict):
         """Load state dict, requires the state_dict be the pytorch form
 
         Args:
@@ -534,3 +540,46 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         self.optim.load_state_dict(zero_state_dict)
         zero_state_dict = dict()
+
+    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+        """Yield the optimizer state as a sequence of shard dictionaries.
+        The max size of each shard is specified by ``max_shard_size``.
+        Only the 'state' part of the state_dict is included.
+
+        Args:
+            max_shard_size (int, optional): max size of a state shard (in MB). Defaults to 1024.
+
+        Yields:
+            Iterator[Tuple[Dict, int]]: A generator of state dict shards together with their sizes
+        """
+        ret_block = dict()
+        ret_block_size = 0
+
+        local_states = self.optim.state_dict()['state']
+        for param_idx, states in local_states.items():
+            current_block_size = 0
+            current_block = copy.deepcopy(states)
+
+            # find the working param of the current param_id
+            for group_id, pg in self._master_param_groups_of_current_rank.items():
+                if (group_id + 1) * len(pg) < param_idx:
+                    continue
+                master_param = pg[param_idx - (group_id) * len(pg)]
+                working_param = self._param_store.master_to_working_param[id(master_param)]
+
+            for k, v in states.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    # all-gather the sharded state and trim the padding so that the stored
+                    # tensor has the same shape as the working param
+                    state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+                    dist.all_gather(state_tensor, v, group=self.dp_pg)
+                    state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    current_block_size += state_tensor.numel()
+                    current_block[k] = state_tensor
+
+            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                yield ret_block, ret_block_size
+                ret_block = dict()
+                ret_block_size = 0
+
+            ret_block[param_idx] = current_block
+            ret_block_size += current_block_size
+
+        yield ret_block, ret_block_size
diff --git a/colossalai/zero/low_level/readme.md b/colossalai/zero/low_level/readme.md
new file mode 100644
index 000000000..aa92159d8
--- /dev/null
+++ b/colossalai/zero/low_level/readme.md
@@ -0,0 +1,54 @@
+# Low Level ZeRO
+>Low Level ZeRO == ZeRO-DP stage 1 and 2; we denote it as ZeRO below.
+
+## Design
+### Notation
+`p32` denotes the param copy kept in the optimizer
+`p` denotes the model param
+`g` denotes the gradient
+
+### INIT
+In low level zero (stage 1 and 2), `p32` is split. Different from the previous implementation, we split each `p32` evenly by world_size. Thus, rank0 gets a param list such as `[p-00, p-10]`, rank1 gets `[p-01, p-11]`, etc.
+*(figure: each `p32` partitioned evenly across the ranks)*
+
+In the detailed implementation, we first pad `p` so that it can be split evenly by world_size if needed. Then we view it with the shape `[world_size, -1]`, and each rank gets its own part of `p32` by cloning.
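+
+A minimal sketch of this partitioning, assuming a helper written just for illustration (`split_master_param` is not part of the ColossalAI API):
+```
+import torch
+
+
+def split_master_param(p: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
+    # flatten the fp32 copy and pad it so that it divides evenly by world_size
+    flat = p.detach().float().flatten()
+    pad = (world_size - flat.numel() % world_size) % world_size
+    if pad > 0:
+        flat = torch.nn.functional.pad(flat, [0, pad])
+    # view as [world_size, -1] and keep only this rank's slice
+    return flat.view(world_size, -1)[rank].clone()
+```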
+
+### BWD
+To make communication more efficient, a gradient is first added to a bucket. When the bucket is full, each `g` in it is reshaped as `[world_size, -1]`, and the `[local_rank]` parts are grouped together by rank.
+The data structure looks like this:
+```
+{
+0: [g-00, g-10],
+1: [g-01, g-11],
+2: [g-02, g-12]
+}
+```
+After that, the gradients are flattened by rank, and the data structure looks like this:
+```
+# g-0 means flatten([g-00, g-10])
+{
+0: [g-0],
+1: [g-1],
+2: [g-2]
+}
+```
+For zero1, we iterate over the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
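+
+A rough sketch of this reduction step, assuming the per-rank flattened buffers are already built and equal-sized (`reduce_flat_grads` is illustrative, not the actual bucket code):
+```
+from typing import List
+
+import torch
+import torch.distributed as dist
+
+
+def reduce_flat_grads(flat_grads_per_rank: List[torch.Tensor], stage: int) -> torch.Tensor:
+    # flat_grads_per_rank[i] is the flattened slice destined for rank i
+    if stage == 1:
+        # zero1: every rank keeps the full gradient, so all_reduce each slice
+        for flat in flat_grads_per_rank:
+            dist.all_reduce(flat)
+        return flat_grads_per_rank[dist.get_rank()]
+    # zero2: each rank only needs its own slice, so a single reduce_scatter is enough
+    out = torch.empty_like(flat_grads_per_rank[0])
+    dist.reduce_scatter(out, flat_grads_per_rank)
+    return out
+```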
+
+### Optim
+Since each rank holds its own `p32` and the counterpart `g`, it is quite easy to do `optim.step()`.
+
+However, we have to consider the situation of a dropped layer, for instance:
+```
+class MlpModel(nn.Module):
+    def __init__(self):
+        super(MlpModel, self).__init__()
+        self.linear1 = nn.Linear(128, 256)
+        self.drop_linear = nn.Linear(256, 256)    # defined but never used in forward
+        self.linear2 = nn.Linear(256, 512)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+```
+The solution is to build a mapping between `p32`, `p`, and `g`. Before `optim.step()`, we collect every `p` with `requires_grad=True` and `p.grad != None` as a real working param, and select the counterpart `p32` and `g`.
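+
+One possible shape of that collection step, sketched only to make the idea concrete (`collect_working_params` and `master_of` are hypothetical names, not ColossalAI attributes):
+```
+def collect_working_params(params, master_of):
+    # keep only params that actually produced a gradient in this step
+    working, masters, grads = [], [], []
+    for p in params:
+        if p.requires_grad and p.grad is not None:
+            working.append(p)
+            masters.append(master_of[id(p)])    # the counterpart p32 shard
+            grads.append(p.grad)
+    return working, masters, grads
+```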
diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
index c51b54c82..a94e8d42c 100644
--- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py
@@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
         booster.save_model(model, model_ckpt_path, shard=shard)
-        if not shard:
-            # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
-            booster.save_optimizer(optimizer, optimizer_ckpt_path)
+        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+
         dist.barrier()
 
         new_model = resnet18()
@@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
 
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
-        if not shard:
-            booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-            check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+
+        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
+        check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
 
 
 def run_dist(rank, world_size, port):
@@ -62,3 +61,7 @@ def run_dist(rank, world_size, port):
 @rerun_if_address_is_in_use()
 def test_low_level_zero_checkpointIO():
     spawn(run_dist, 2)
+
+
+if __name__ == "__main__":
+    test_low_level_zero_checkpointIO()