[zero] add L2 gradient clipping for ZeRO (#2112)

* [zero] add L2 gradient clipping

* [testing] add MlpModel

* [zero] add unit test for grad clipping

* fix atol
This commit is contained in:
HELSON
2022-12-09 18:09:17 +08:00
committed by GitHub
parent 70a8556946
commit 63fbba3c19
5 changed files with 194 additions and 11 deletions

View File

@@ -51,7 +51,6 @@ def alloc_storage(tensor: torch.Tensor) -> None:
class Chunk:
_total_number = 0
def __init__(self,
@@ -140,6 +139,10 @@ class Chunk:
# if the cpu_shard has been visited during the training step, the flag is True
self.cpu_vis_flag = False
# whether to record l2 norm for the gradient clipping calculation
self.l2_norm_flag = False
self.l2_norm = None
@property
def memory_usage(self) -> Dict[str, int]:
cuda_memory = 0
@@ -213,16 +216,28 @@ class Chunk:
@property
def has_inf_or_nan(self) -> bool:
"""Check if the chunk has inf or nan values in CUDA.
"""Check if the chunk has inf or nan values on CUDA.
"""
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # only check in CUDA
assert self.cuda_shard is not None # only check on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
def set_l2_norm(self) -> None:
"""Record l2 norm of this chunks on CUDA.
"""
assert self.l2_norm is None, "you are calculating the l2 norm twice"
if self.is_gathered:
valid_tensor = self.chunk_total[:self.utilized_size]
else:
assert self.cuda_shard is not None # calculate on CUDA
valid_tensor = self.cuda_shard[:self.valid_end]
chunk_l2_norm = valid_tensor.data.float().norm(2)
self.l2_norm = chunk_l2_norm.item()**2
def append_tensor(self, tensor: torch.Tensor):
"""Add a tensor to the chunk.

View File

@@ -1,3 +1,4 @@
import math
from enum import Enum
from typing import Any, Dict, Set, Tuple
@@ -56,6 +57,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
clipping_norm: float = 0.0,
norm_type: float = 2.0,
**defaults: Any):
super().__init__(optim)
assert isinstance(module, ZeroDDP)
@@ -66,11 +69,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = clipping_norm > 0.0
self.max_norm = clipping_norm
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
for p, fp32_p in zip(params_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
self.chunk16_set.add(chunk_16)
self.__init__optimizer()
@@ -128,12 +137,45 @@ class ZeroOptimizer(ColossalaiOptimizer):
return self._found_overflow.item() > 0
def _unscale_grads(self):
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
assert c16.l2_norm is not None
if c16.is_gathered:
norm_sqr += c16.l2_norm
else:
# this chunk is sharded, use communication to collect total norm
if c16.torch_pg not in group_to_norm:
group_to_norm[c16.torch_pg] = 0.0
group_to_norm[c16.torch_pg] += c16.l2_norm
c16.l2_norm = None # clear l2 norm
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
for group, part_norm in group_to_norm.items():
comm_buffer.fill_(part_norm)
dist.all_reduce(comm_buffer, group=group)
norm_sqr += comm_buffer.item()
global_norm = math.sqrt(norm_sqr)
return global_norm
def _unscale_and_clip_grads(self):
assert self.optim_state == OptimState.SCALED
combined_scale = self.loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / self.loss_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * self.loss_scale
for group in self.optim.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
p.grad.data.div_(combined_scale)
self.optim_state = OptimState.UNSCALED
@property
@@ -147,16 +189,21 @@ class ZeroOptimizer(ColossalaiOptimizer):
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._unscale_and_clip_grads()
self.grad_scaler.update(found_inf)
ret = self.optim.step(*args, **kwargs)
self._register_states()
self.zero_grad()

View File

@@ -302,7 +302,11 @@ class ZeroDDP(ColoDDP):
chunk.chunk_total.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad