mirror of https://github.com/hpcaitech/ColossalAI.git
[zero]support no_sync method for zero1 plugin (#4138)
* support no sync for zero1 plugin
* polish
* polish
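This change lets the caller control when ZeRO-1 gradients are synchronized, instead of the optimizer counting accumulation steps internally. Below is a minimal sketch of the training-loop pattern that enables, assuming an optimizer wrapper exposing the `backward()`, `step()`, `zero_grad()`, and new `no_sync()` methods that appear in the diff; `model`, `criterion`, `dataloader`, and `accumulate_steps` are placeholders, not part of the plugin's API.

```python
# Illustrative pattern only: gradient accumulation driven by no_sync().
# `optimizer` stands for the ZeRO-1 optimizer wrapper from the diff below;
# model / criterion / dataloader / accumulate_steps are placeholders.
accumulate_steps = 4

for step, (inputs, targets) in enumerate(dataloader):
    loss = criterion(model(inputs), targets)
    if (step + 1) % accumulate_steps != 0:
        # Intermediate micro-batch: keep gradients local, skip reduction.
        with optimizer.no_sync():
            optimizer.backward(loss)
    else:
        # Boundary micro-batch: reduce/synchronize gradients, then update.
        optimizer.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

The diff below wires this up by replacing the optimizer's internal accumulation counters with a `require_grad_sync` flag checked in `backward()` and `step()`.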
@@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils import conditional_context
from colossalai.utils.cuda import get_current_device

from ._utils import (
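The newly imported `conditional_context` helper is used further down to enter a context manager only when a flag is set. A roughly equivalent standalone helper is sketched here for reference; it approximates, but is not, the exact `colossalai.utils` implementation.

```python
from contextlib import contextmanager


@contextmanager
def conditional_context(context_manager, enable: bool = True):
    """Enter `context_manager` only when `enable` is True; otherwise do nothing."""
    if enable:
        with context_manager:
            yield
    else:
        yield
```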
@@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
        return False


class LowLevelZeroOptimizer(ColossalaiOptimizer):
class LowLevelZeroOptimizer(OptimizerWrapper):
    """Optimizer used for ZeRO-1 and ZeRO-2.
    """
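The base class moves from `ColossalaiOptimizer` to `OptimizerWrapper`. In broad terms such a wrapper delegates the standard optimizer interface to an inner `torch.optim.Optimizer`; the following is a minimal standalone sketch of that delegation pattern, not the actual `colossalai.interface.OptimizerWrapper`.

```python
import torch


class SimpleOptimizerWrapper:
    """Minimal delegation wrapper around a torch optimizer (illustrative only)."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def backward(self, loss: torch.Tensor, *args, **kwargs):
        loss.backward(*args, **kwargs)

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)
```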
@@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 overlap_communication: bool = False,
                 partition_grad: bool = False,    # stage 2 flag
                 cpu_offload: bool = False,    # cpu offload
                 grad_accumulate_interval: int = 1,
                 forced_dtype: Optional[torch.dtype] = None):

        assert not (partition_grad and grad_accumulate_interval > 1), \
            "gradient accumulation is not compatible with ZeRO-2"
        # TODO:
        # 1. process group api
        # 2. checkpoint IO

        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
        self._dtype = self.optim.param_groups[0]['params'][0].dtype
        self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

        # grad accumulation
        self.require_grad_sync = True
        self._accumulate_intervel = grad_accumulate_interval
        self._accumulate_step = 0

        colo_pg = self._search_colo_process_group()
        if isinstance(colo_pg, ProcessGroup):
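The `require_grad_sync` flag replaces the internal `_accumulate_intervel`/`_accumulate_step` counters: the caller now decides when gradients should be synchronized. A typical way to expose such a flag is a `no_sync()` context manager that clears it for the duration of a block and restores it afterwards; the following is a hedged sketch of that pattern, not the method as it appears in the file.

```python
from contextlib import contextmanager


class _GradSyncSwitch:
    """Illustrative stand-in for the optimizer's grad-sync flag handling."""

    def __init__(self):
        self.require_grad_sync = True

    @contextmanager
    def no_sync(self):
        # Temporarily disable gradient synchronization, restoring the previous
        # value even if the wrapped block raises.
        old = self.require_grad_sync
        self.require_grad_sync = False
        try:
            yield
        finally:
            self.require_grad_sync = old
```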
@@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    ################################

    def backward(self, loss, retain_graph=False):
        assert not(self._partition_grads and not self.require_grad_sync), \
            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)

        self._accumulate_step += 1
        no_sync = self._accumulate_step < self._accumulate_intervel
        with conditional_context(self.no_sync(), enable=no_sync):
            loss.backward(retain_graph=retain_graph)
        loss.backward(retain_graph=retain_graph)

        if no_sync:
        if not self.require_grad_sync:
            return

        self._reduce_grad(self._partition_grads)
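The new `backward` no longer counts accumulation steps itself: it always runs the autograd backward and only reduces gradients when `require_grad_sync` is set, which is also why ZeRO-2 gradient partitioning is asserted to be incompatible with `no_sync`. A condensed sketch of that control flow follows; attribute names are taken from the diff, the mixed-precision `pre_backward` hook is omitted, and this is not the real method.

```python
def backward(optimizer, loss, retain_graph: bool = False):
    """Illustrative control flow of the updated backward (not the real method)."""
    # ZeRO-2 partitions gradients during reduction, so skipping the reduction
    # step for local accumulation is not supported.
    assert not (optimizer._partition_grads and not optimizer.require_grad_sync), \
        "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

    loss.backward(retain_graph=retain_graph)

    if not optimizer.require_grad_sync:
        # Accumulating locally under no_sync(): reduce later.
        return

    optimizer._reduce_grad(optimizer._partition_grads)
```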
@@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        if not self._accumulate_step == self._accumulate_intervel:
        if not self.require_grad_sync:
            return

        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
            if self._verbose:
                self._logger.info(f'Found overflow. Skip step')
            self.zero_grad()
            self._accumulate_step -= 1
            return

        # record all grads for unscale and clip
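`step()` follows the same flag: it becomes a no-op while gradients are still being accumulated, and when the mixed-precision mixin reports overflow it clears gradients and skips the update, with the old `_accumulate_step` bookkeeping dropped. A condensed sketch of that flow (names follow the diff; not the real method):

```python
def step(optimizer, closure=None):
    """Illustrative control flow of the updated step (not the real method)."""
    assert closure is None, 'closure is not supported by step()'

    if not optimizer.require_grad_sync:
        # Still accumulating gradients under no_sync(); nothing to update yet.
        return

    mixin = optimizer.mixed_precision_mixin
    if mixin is not None and mixin.should_skip_step():
        # Overflow detected: drop this step's gradients and skip the update.
        optimizer.zero_grad()
        return

    # ... unscale, clip, and apply the actual parameter update ...
```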
@@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

            self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]

        # reset accumulate step
        self._accumulate_step = 0

    #############################
    # Mixed Precision Utilities #
    #############################