[zero] support no_sync method for zero1 plugin (#4138)

* support no_sync for zero1 plugin

* polish

* polish
Author: LuGY
Date: 2023-07-04 12:00:33 +08:00
Committed by: Hongxin Liu
Parent: c6ab96983a
Commit: 79cf1b5f33
8 changed files with 45 additions and 49 deletions
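
For context, a minimal sketch of how gradient accumulation can be driven through the optimizer-level no_sync() after this change; model, criterion, train_dataloader and accumulate_interval are illustrative placeholders, not part of this diff, and per the new assert in backward() this applies to ZeRO-1 only (partition_grad=False):

    for step_idx, (data, label) in enumerate(train_dataloader):
        loss = criterion(model(data), label)
        if (step_idx + 1) % accumulate_interval != 0:
            # accumulation step: keep gradients local, skip reduction
            with optimizer.no_sync():
                optimizer.backward(loss)
        else:
            # synchronization step: gradients are reduced, then the update runs
            optimizer.backward(loss)
            optimizer.step()
            optimizer.zero_grad()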

@@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils import conditional_context
 from colossalai.utils.cuda import get_current_device
 from ._utils import (
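
The removed conditional_context helper (imported above from colossalai.utils) was only needed by the old backward() path; judging from its call site, it behaves roughly like this sketch (an assumption, the real implementation is not shown in this commit):

    from contextlib import contextmanager

    @contextmanager
    def conditional_context(context_manager, enable: bool = True):
        # Enter the given context manager only when enable is True.
        if enable:
            with context_manager:
                yield
        else:
            yield
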
@@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         return False
-class LowLevelZeroOptimizer(ColossalaiOptimizer):
+class LowLevelZeroOptimizer(OptimizerWrapper):
     """Optimizer used for ZeRO-1 and ZeRO-2.
     """
@@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                  overlap_communication: bool = False,
                  partition_grad: bool = False,    # stage 2 flag
                  cpu_offload: bool = False,    # cpu offload
-                 grad_accumulate_interval: int = 1,
                  forced_dtype: Optional[torch.dtype] = None):
-        assert not (partition_grad and grad_accumulate_interval > 1), \
-            "gradient accumulation is not compatible with ZeRO-2"
+        # TODO:
+        # 1. process group api
+        # 2. checkpoint IO
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # grad accumulation
         self.require_grad_sync = True
-        self._accumulate_intervel = grad_accumulate_interval
-        self._accumulate_step = 0
         colo_pg = self._search_colo_process_group()
         if isinstance(colo_pg, ProcessGroup):
@@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     ################################
     def backward(self, loss, retain_graph=False):
+        assert not(self._partition_grads and not self.require_grad_sync), \
+            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
-        self._accumulate_step += 1
-        no_sync = self._accumulate_step < self._accumulate_intervel
-        with conditional_context(self.no_sync(), enable=no_sync):
-            loss.backward(retain_graph=retain_graph)
-        if no_sync:
+        loss.backward(retain_graph=retain_graph)
+        if not self.require_grad_sync:
             return
         self._reduce_grad(self._partition_grads)
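
The no_sync() method itself is outside the hunks shown here; a minimal sketch of a flag-toggling context manager consistent with how require_grad_sync is used above (the method body is an assumption, not taken from this commit):

    from contextlib import contextmanager

    @contextmanager
    def no_sync(self):
        # Sketch: disable gradient synchronization for the enclosed backward
        # passes and restore the previous setting on exit.
        old_require_grad_sync = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = old_require_grad_sync
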
@@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'
-        if not self._accumulate_step == self._accumulate_intervel:
+        if not self.require_grad_sync:
             return
         if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
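
A consequence of the new guard above: step() silently returns while require_grad_sync is False, so the parameter update has to be requested outside the no_sync() block (sketch with placeholder optimizer and loss):

    with optimizer.no_sync():
        optimizer.backward(loss)    # gradients stay local, no reduction
        optimizer.step()            # no-op here: require_grad_sync is False
    optimizer.backward(loss)        # outside no_sync: gradients are reduced
    optimizer.step()                # parameters are actually updated
    optimizer.zero_grad()
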
@@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             if self._verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self.zero_grad()
-            self._accumulate_step -= 1
             return
         # record all grads for unscale and clip
@@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
-        # reset accumulate step
-        self._accumulate_step = 0
     #############################
     # Mixed Precision Utilities #
     #############################