[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)

ver217 authored 2022-08-11 22:58:58 +08:00, committed by GitHub
parent 74bee5f7e8
commit 821c6172e2
3 changed files with 232 additions and 5 deletions


@@ -8,9 +8,11 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, disposable
+from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
 from collections import defaultdict, abc as container_abcs
 from copy import deepcopy
 from itertools import chain
+from torch._six import inf
 
 
 class OptimState(Enum):
@@ -143,11 +145,38 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self._update_fp16_params()
         return ret
 
-    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
+    def compute_grad_norm(self, norm_type: float = 2.0) -> float:
+        norm_type = float(norm_type)
+        if not self.chunk_manager.enable_distributed_storage:
+            return compute_grad_norm(self.module.parameters(), norm_type)
+
+        non_distributed_params = []
+        distributed_params = []
+        for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                non_distributed_params.append(p)
+            else:
+                distributed_params.append(p)
+        non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
+        distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
+                                               device=get_current_device())
+        if norm_type == inf:
+            dist.all_reduce(distributed_norm_tensor,
+                            op=dist.ReduceOp.MAX,
+                            group=self.chunk_manager.process_group.dp_process_group())
+            total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
+        else:
+            dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
+            total_norm = non_distributed_norm + distributed_norm_tensor.item()
+            total_norm = total_norm**(1 / norm_type)
+        return total_norm
+
+    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
         if self.optim_state == OptimState.SCALED:
             self._unscale_grads()
-        # TODO(ver217): fix zero clip grad norm
-        return super().clip_grad_norm(model, max_norm)
+        total_norm = self.compute_grad_norm(norm_type)
+        _clip_grad_norm(self.module.parameters(), max_norm, total_norm)
+        return total_norm
 
     def backward(self, loss: torch.Tensor):
         loss = self.loss_scale * loss
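
For reference, the arithmetic behind compute_grad_norm can be exercised on a single process. The helper bodies below are a minimal sketch of what the imports from colossalai.utils.common are assumed to do (sum of |g|^p per parameter for _compute_grad_lp, in-place rescaling by max_norm / total_norm for _clip_grad_norm); they are illustrative stand-ins, not the library code. In the distributed path of the diff, the only extra step is the all_reduce over the data-parallel group before the p-th root is taken.

import torch

def _compute_grad_lp(params, norm_type: float = 2.0) -> float:
    # Sum of |g|^p over all gradients (no p-th root yet); max |g| for the inf norm.
    grads = [p.grad.detach() for p in params if p.grad is not None]
    if norm_type == float('inf'):
        return max((g.abs().max().item() for g in grads), default=0.0)
    return sum(g.abs().pow(norm_type).sum().item() for g in grads)

def _clip_grad_norm(params, max_norm: float, total_norm: float) -> None:
    # Scale gradients in place so their combined norm does not exceed max_norm.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coef)

# Single-process walk-through of the same math: accumulate |g|^p,
# (all-reduce it in the distributed case), then take the p-th root.
model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()
lp_sum = _compute_grad_lp(model.parameters(), norm_type=2.0)
total_norm = lp_sum ** (1 / 2.0)
_clip_grad_norm(model.parameters(), max_norm=1.0, total_norm=total_norm)
print(total_norm)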
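
Downstream, the new clip_grad_norm signature is meant to be called between the backward pass and the optimizer step. A hypothetical call pattern (the model/optimizer names and the ZeroOptimizer construction are assumed, not part of this diff):

# optimizer is assumed to be a ZeroOptimizer wrapping `model` (construction omitted)
optimizer.zero_grad()
optimizer.backward(loss)                                   # applies loss scaling internally
grad_norm = optimizer.clip_grad_norm(model, max_norm=1.0)  # returns the total gradient norm
optimizer.step()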