diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index cee547b33..ec3dc7fc1 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from colossalai.checkpoint_io import GeneralCheckpointIO -from colossalai.interface import ModelWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from .accelerator import Accelerator from .mixed_precision import MixedPrecision, mixed_precision_factory @@ -153,18 +153,20 @@ class Booster: # return loss or outputs if needed pass - def no_sync(self, model: nn.Module) -> contextmanager: + def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager: """Context manager to disable gradient synchronization across DP process groups. + Support torch DDP and Low Level ZeRO-1 for now. Args: - model (nn.Module): The model to be disabled gradient synchronization. + model (nn.Module): The model to be disabled gradient synchronization, for DDP + optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1 Returns: contextmanager: Context to disable gradient synchronization. """ assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' - assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' - return self.plugin.no_sync(model) + assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' + return self.plugin.no_sync(model, optimizer) def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True): """Load model from checkpoint. diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 7b6e17337..0f5ba6e9a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -408,5 +408,5 @@ class GeminiPlugin(DPPluginBase): def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 3ec0d3409..0a3221b23 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase): norm_type=norm_type) self.verbose = verbose + # set class name with stage, for better error message + setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") + def support_no_sync(self) -> bool: - return False + return self.stage == 1 def control_precision(self) -> bool: return True @@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase): def get_checkpoint_io(self) -> CheckpointIO: return LowLevelZeroCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: - raise NotImplementedError + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: + assert isinstance(optimizer, LowLevelZeroOptimizer) + return optimizer.optim.no_sync() diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index aa78f6827..fb21e57f4 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -61,7 +61,7 @@ class Plugin(ABC): pass @abstractmethod - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: """ Context manager to disable gradient synchronization. """ diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 71b435155..f3f779c88 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -168,6 +168,6 @@ class TorchDDPPlugin(DPPluginBase): def get_checkpoint_io(self) -> CheckpointIO: return TorchDDPCheckpointIO() - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.' return model.module.no_sync() diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index abfffa9b0..fb7b5baad 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -177,7 +177,7 @@ class TorchFSDPPlugin(DPPluginBase): def support_no_sync(self) -> bool: False - def no_sync(self, model: nn.Module) -> Iterator[None]: + def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: raise NotImplementedError("Torch fsdp no_sync func not supported yet.") def control_precision(self) -> bool: diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8743cab33..615c87097 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import ( ) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc +from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils import conditional_context from colossalai.utils.cuda import get_current_device from ._utils import ( @@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): return False -class LowLevelZeroOptimizer(ColossalaiOptimizer): +class LowLevelZeroOptimizer(OptimizerWrapper): """Optimizer used for ZeRO-1 and ZeRO-2. """ @@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload - grad_accumulate_interval: int = 1, forced_dtype: Optional[torch.dtype] = None): - assert not (partition_grad and grad_accumulate_interval > 1), \ - "gradient accumulation is not compatible with ZeRO-2" + # TODO: + # 1. process group api + # 2. checkpoint IO + super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() @@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # grad accumulation self.require_grad_sync = True - self._accumulate_intervel = grad_accumulate_interval - self._accumulate_step = 0 colo_pg = self._search_colo_process_group() if isinstance(colo_pg, ProcessGroup): @@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): ################################ def backward(self, loss, retain_graph=False): + assert not(self._partition_grads and not self.require_grad_sync), \ + "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) - self._accumulate_step += 1 - no_sync = self._accumulate_step < self._accumulate_intervel - with conditional_context(self.no_sync(), enable=no_sync): - loss.backward(retain_graph=retain_graph) + loss.backward(retain_graph=retain_graph) - if no_sync: + if not self.require_grad_sync: return self._reduce_grad(self._partition_grads) @@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - if not self._accumulate_step == self._accumulate_intervel: + if not self.require_grad_sync: return if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): @@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): if self._verbose: self._logger.info(f'Found overflow. Skip step') self.zero_grad() - self._accumulate_step -= 1 return # record all grads for unscale and clip @@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] - # reset accumulate step - self._accumulate_step = 0 - ############################# # Mixed Precision Utilities # ############################# diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index ac1f677f9..a1d14f1d5 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -9,6 +9,7 @@ from torch.testing import assert_close import colossalai from colossalai.testing import spawn from colossalai.testing.random import seed_all +from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer @@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc(): overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, - grad_accumulate_interval=2, verbose=True) zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, - clip_grad_norm=1.0, - grad_accumulate_interval=2) + clip_grad_norm=1.0) # create data seed_all(2021 + local_rank) input_data1 = torch.randn(32, 128).cuda() @@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc(): assert torch.equal(zero1_output, zero2_output) # zero-dp backward - zero1_optimizer.backward(zero1_output.sum().float()) - zero2_optimizer.backward(zero2_output.sum().float()) + no_sync = number == 0 + with conditional_context(zero1_optimizer.no_sync(), no_sync): + zero1_optimizer.backward(zero1_output.sum().float()) + with conditional_context(zero2_optimizer.no_sync(), no_sync): + zero2_optimizer.backward(zero2_output.sum().float()) if check_flag: for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): @@ -101,8 +103,7 @@ def exam_zero_1_grad_acc(): zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, - clip_grad_norm=1.0, - grad_accumulate_interval=2) + clip_grad_norm=1.0) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -112,20 +113,15 @@ def exam_zero_1_grad_acc(): input_data2 = torch.randn(32, 128).cuda() def fwd_bwd_func(number, cur_data, check_flag): - # zero-dp forward - zero_output = zero_model(cur_data) - # torch-ddp forward + no_sync = number == 0 + # zero1 fwd and bwd + with conditional_context(zero_optimizer.no_sync(), no_sync): + zero_output = zero_model(cur_data) + zero_optimizer.backward(zero_output.sum().float()) - # zero-dp backward - zero_optimizer.backward(zero_output.sum().float()) - # torch-ddp backward - if number < 1: - with torch_model.no_sync(): - torch_output = torch_model(cur_data) - assert torch.equal(zero_output, torch_output) - torch_output.sum().backward() - else: + # torch-ddp fwd and bwd + with conditional_context(torch_model.no_sync(), no_sync): torch_output = torch_model(cur_data) assert torch.equal(zero_output, torch_output) torch_output.sum().backward() @@ -133,7 +129,6 @@ def exam_zero_1_grad_acc(): if check_flag: # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) assert torch.equal(p.grad, z1p.grad) fwd_bwd_func(0, input_data1, True)