Mirror of https://github.com/hpcaitech/ColossalAI.git
[utils] Impl clip_grad_norm for ColoTensor and ZeroOptimizer (#1442)
@@ -8,9 +8,11 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, disposable
+from colossalai.utils.common import _compute_grad_lp, compute_grad_norm, _clip_grad_norm
 from collections import defaultdict, abc as container_abcs
 from copy import deepcopy
 from itertools import chain
+from torch._six import inf
 
 
 class OptimState(Enum):
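Note: the hunk above pulls in `_compute_grad_lp`, `compute_grad_norm`, and `_clip_grad_norm` from `colossalai.utils.common`. As a rough, hypothetical mental model of those helpers (a sketch, not the library's actual implementation): `_compute_grad_lp` accumulates sum(|g|^p) over a parameter list (or the max |g| for the inf norm), `compute_grad_norm` turns that into a full L^p norm, and `_clip_grad_norm` rescales gradients in place when the total norm exceeds `max_norm`.

```python
# Hypothetical sketch of the helpers imported above, for orientation only;
# the real implementations live in colossalai/utils/common.py.
from typing import Iterable
from math import inf  # the diff uses torch._six.inf, which is the same value

import torch


def _compute_grad_lp(params: Iterable[torch.Tensor], norm_type: float = 2.0) -> float:
    """Sum of |grad|^p over the given parameters (max |grad| for the inf norm)."""
    grads = [p.grad for p in params if p.grad is not None]
    if norm_type == inf:
        return max((g.abs().max().item() for g in grads), default=0.0)
    return sum(g.abs().pow(norm_type).sum().item() for g in grads)


def compute_grad_norm(params: Iterable[torch.Tensor], norm_type: float = 2.0) -> float:
    """Full L^p gradient norm of a (non-distributed) parameter list."""
    lp = _compute_grad_lp(params, norm_type)
    return lp if norm_type == inf else lp**(1 / norm_type)


def _clip_grad_norm(params: Iterable[torch.Tensor], max_norm: float, total_norm: float) -> None:
    """Scale gradients in place so their combined norm does not exceed max_norm."""
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in params:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coef)
```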
@@ -143,11 +145,38 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self._update_fp16_params()
         return ret
 
-    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float):
+    def compute_grad_norm(self, norm_type: float = 2.0) -> float:
+        norm_type = float(norm_type)
+        if not self.chunk_manager.enable_distributed_storage:
+            return compute_grad_norm(self.module.parameters(), norm_type)
+
+        non_distributed_params = []
+        distributed_params = []
+        for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                non_distributed_params.append(p)
+            else:
+                distributed_params.append(p)
+        non_distributed_norm = _compute_grad_lp(non_distributed_params, norm_type)
+        distributed_norm_tensor = torch.tensor([_compute_grad_lp(distributed_params, norm_type)],
+                                               device=get_current_device())
+        if norm_type == inf:
+            dist.all_reduce(distributed_norm_tensor,
+                            op=dist.ReduceOp.MAX,
+                            group=self.chunk_manager.process_group.dp_process_group())
+            total_norm = max(non_distributed_norm, distributed_norm_tensor.item())
+        else:
+            dist.all_reduce(distributed_norm_tensor, group=self.chunk_manager.process_group.dp_process_group())
+            total_norm = non_distributed_norm + distributed_norm_tensor.item()
+            total_norm = total_norm**(1 / norm_type)
+        return total_norm
+
+    def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
         if self.optim_state == OptimState.SCALED:
             self._unscale_grads()
-        # TODO(ver217): fix zero clip grad norm
-        return super().clip_grad_norm(model, max_norm)
+        total_norm = self.compute_grad_norm(norm_type)
+        _clip_grad_norm(self.module.parameters(), max_norm, total_norm)
+        return total_norm
 
     def backward(self, loss: torch.Tensor):
         loss = self.loss_scale * loss
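Note: a minimal, hypothetical training step using the new API might look like the sketch below. `model`, `criterion`, `train_loader`, and `optimizer` (a `ZeroOptimizer` built through ColossalAI's usual setup) are placeholders, not taken from the repo's examples; `clip_grad_norm` unscales the gradients if needed, computes the (distributed) total norm, clips in place, and returns the norm so it can be logged.

```python
# Hypothetical usage sketch of the new ZeroOptimizer.clip_grad_norm API.
# model, criterion, train_loader, and optimizer are assumed to come from the
# normal ColossalAI setup; only the optimizer calls below are the point here.
for data, label in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(data), label)
    optimizer.backward(loss)                                   # scales the loss, then backward
    grad_norm = optimizer.clip_grad_norm(model, max_norm=1.0)  # unscale + clip, returns total norm
    optimizer.step()
```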