From 807e01a4bae5d1c49747bcb4ae69c98871bce9ff Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 5 Sep 2023 15:04:02 +0800 Subject: [PATCH] [zero] hotfix master param sync (#4618) * [zero] add method to update master params * [zero] update zero plugin * [plugin] update low level zero plugin --- .../booster/plugin/low_level_zero_plugin.py | 123 ++++++++++++------ colossalai/interface/__init__.py | 4 +- colossalai/interface/model.py | 11 ++ colossalai/zero/low_level/low_level_optim.py | 17 +++ .../test_low_level_zero_checkpoint_io.py | 12 ++ 5 files changed, 122 insertions(+), 45 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 6efafc56d..9adb4beec 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -3,6 +3,7 @@ import os import warnings from functools import partial from pathlib import Path +from types import MethodType from typing import Callable, Iterator, List, Optional, Tuple, Union import torch @@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import ( sharded_optimizer_loading_epilogue, unwrap_optimizer, ) -from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device -from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import LowLevelZeroOptimizer from .dp_plugin_base import DPPluginBase from .torch_ddp_plugin import TorchDDPCheckpointIO @@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16): SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] +class LowLevelZeroModel(ModelWrapper, AMPModelMixin): + + def __init__(self, module: nn.Module, precision: str) -> None: + super().__init__(module) + self.dtype = None + if precision == 'fp16': + self.dtype = torch.float16 + elif precision == 'bf16': + self.dtype = torch.bfloat16 + if self.dtype is not None: + module = module.to(self.dtype) + module = module.to(get_current_device()) + self.module = module + self.convert_fn = None + if self.dtype is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) + + def forward(self, *args, **kwargs): + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) + return super().forward(*args, **kwargs) + + def unwrap(self): + # TODO(ver217): this is a workaround for loading model + return self + + class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False): @@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): sharded_optimizer_loading_epilogue(optimizer) + def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, + use_safetensors: bool): + assert isinstance(model, LowLevelZeroModel) + super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors) -class LowLevelZeroModel(ModelWrapper): + def save_sharded_model(self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = True, + prefix: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False): + assert isinstance(model, LowLevelZeroModel) + super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, + use_safetensors) - def 
__init__(self, module: nn.Module, stage: int, precision: str) -> None: - super().__init__(module) - self.dtype = None - if precision == 'fp16': - self.dtype = torch.float16 - elif precision == 'bf16': - self.dtype = torch.bfloat16 - module = zero_model_wrapper(module, zero_stage=stage) - if self.dtype is not None: - module = module.to(self.dtype) - module = module.to(get_current_device()) - self.module = module - self.convert_fn = None - if self.dtype is not None: - self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) + def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True): + assert isinstance(model, LowLevelZeroModel) + super().load_unsharded_model(model.module, checkpoint, strict) + model.update_master_params() - def forward(self, *args, **kwargs): - if self.convert_fn is not None: - args = tree_map(self.convert_fn, args) - kwargs = tree_map(self.convert_fn, kwargs) - return super().forward(*args, **kwargs) + def load_sharded_model(self, + model: LowLevelZeroModel, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False, + load_sub_module: bool = True): + assert isinstance(model, LowLevelZeroModel) + super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module) + model.update_master_params() class LowLevelZeroPlugin(DPPluginBase): @@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase): super().__init__() assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' - + assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now' self.stage = stage self.precision = precision - self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, - communication_dtype=communication_dtype, - overlap_communication=overlap_communication, - cpu_offload=cpu_offload) - self.optim_kwargs = dict(initial_scale=initial_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - min_scale=min_scale, - max_scale=max_scale, - max_norm=max_norm, - norm_type=norm_type) + self.zero_optim_kwargs = dict( + initial_scale=initial_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + min_scale=min_scale, + max_scale=max_scale, + clip_grad_norm=max_norm, + reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, + communication_dtype=communication_dtype, + overlap_communication=overlap_communication, + cpu_offload=cpu_offload, + partition_grad=(stage == 2), + ) self.verbose = verbose # set class name with stage, for better error message @@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase): ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: if not isinstance(model, ModelWrapper): - model = LowLevelZeroModel(model, self.stage, self.precision) + model = LowLevelZeroModel(model, self.precision) if optimizer is not None and \ not isinstance(optimizer, OptimizerWrapper): - optimizer = zero_optim_wrapper(model.unwrap(), - optimizer, - optim_config=self.zero_optim_config, - **self.optim_kwargs, - verbose=self.verbose) + optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer, + **self.zero_optim_kwargs, + verbose=self.verbose) + # inject update_master_params + model.update_master_params = MethodType(optimizer.update_master_params, 
model) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/interface/__init__.py b/colossalai/interface/__init__.py index 8c658e375..1c3199fc1 100644 --- a/colossalai/interface/__init__.py +++ b/colossalai/interface/__init__.py @@ -1,4 +1,4 @@ -from .model import ModelWrapper +from .model import AMPModelMixin, ModelWrapper from .optimizer import OptimizerWrapper -__all__ = ['OptimizerWrapper', 'ModelWrapper'] +__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin'] diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index a067d7671..7b3d9435d 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -23,3 +23,14 @@ class ModelWrapper(nn.Module): def forward(self, *args, **kwargs): return self.module(*args, **kwargs) + + +class AMPModelMixin: + """This mixin class defines the interface for AMP training. + """ + + def update_master_params(self): + """ + Update the master parameters for AMP training. + """ + pass diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index b4439ab19..d9d6298d7 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple import torch import torch.distributed as dist +import torch.nn as nn from torch.distributed import ProcessGroup from torch.optim import Optimizer @@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ret_block_size += current_block_size yield ret_block, ret_block_size + + def update_master_params(self, model: nn.Module) -> None: + """Update master params from working params + + Args: + model (nn.Module): The model to update master params + """ + for p in model.parameters(): + p_id = id(p) + if p_id in self._param_store.working_to_master_param: + master_param = self._param_store.working_to_master_param[p_id] + padding_size = self._param_store.get_param_padding_size(p) + working_param = p.data.view(-1) + if padding_size > 0: + working_param = torch.nn.functional.pad(working_param, [0, padding_size]) + master_param.copy_(working_param.chunk(self._world_size)[self._local_rank]) diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index 3faa395b5..7ee733b26 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -14,6 +14,7 @@ from colossalai.testing import ( rerun_if_address_is_in_use, spawn, ) +from colossalai.zero import LowLevelZeroOptimizer # stage 1 and 2 process the optimizer/mode the same way @@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + # check master weight + assert isinstance(new_optimizer, LowLevelZeroOptimizer) + working_param_id_set = set(id(p) for p in new_model.parameters()) + for p_id, master_param in new_optimizer._param_store.working_to_master_param.items(): + assert p_id in working_param_id_set + working_param = new_optimizer._param_store.master_to_working_param[id(master_param)] + padding = new_optimizer._param_store.get_param_padding_size(working_param) + padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding)) + working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()] + assert 
torch.equal(working_shard, + master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
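
For reference, a minimal usage sketch of the behaviour this patch enables is given below. It assumes a torchrun launch and ColossalAI's booster API at this revision; the tiny nn.Linear model and the checkpoint path are illustrative placeholders, not part of the patch. After booster.boost(), the plugin binds model.update_master_params to the wrapped LowLevelZeroOptimizer via MethodType, and loading model weights through the booster now finishes by calling it, so the fp32 master shards are refreshed from the newly loaded fp16/bf16 working weights instead of staying stale.

# Usage sketch, not taken from the patch: assumes a torchrun launch;
# the model and checkpoint path below are placeholders.
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# boost() wraps the model in LowLevelZeroModel and the optimizer in
# LowLevelZeroOptimizer, and injects model.update_master_params via MethodType.
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_model(model, 'model_ckpt.pt')
# load_model() now ends by calling model.update_master_params(): each freshly
# loaded fp16 working param is flattened, padded to a multiple of the world
# size, chunked, and this rank's chunk is copied into its fp32 master param.
booster.load_model(model, 'model_ckpt.pt')

The same sync is what the new test above asserts: for every master param in the optimizer's param store, padding and chunking the corresponding working param must reproduce the master shard exactly.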