Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 13:00:52 +00:00)
[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)
* Add clip_grad_norm for hybrid_parallel_plugin
* polish code
* add unittests
* Move tp to a higher-level optimizer interface.
* bug fix
* polish code
@@ -1,7 +1,7 @@
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch
-from torch import Tensor
+from torch import Tensor, inf
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer
@@ -68,8 +68,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
             self.mixed_precision = BF16MixedPrecisionMixin()
         else:
             raise ValueError(f"Unsupported precision: {precision}")
-        if max_norm > 0.0:
-            raise NotImplementedError("max_norm is not supported yet.")
         self.max_norm = max_norm
         self.working_to_master_map: Dict[Parameter, Tensor] = {}
         self.master_to_working_map: Dict[Tensor, Parameter] = {}
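The guard removed above is what makes max_norm a working option rather than a hard error. A minimal usage sketch (not part of this diff; argument values and the surrounding training setup are illustrative assumptions) of enabling gradient clipping through the plugin:

# Sketch, not from this commit: only max_norm is the point here.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# After this change, a positive max_norm is expected to enable gradient
# clipping in the plugin's mixed-precision optimizer instead of raising
# NotImplementedError.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16", max_norm=1.0)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer, criterion, dataloader)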
@@ -102,32 +100,65 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
         return super().zero_grad(*args, **kwargs)

     def _unscale_and_clip_grads(self, total_norm: float) -> None:
+        """
+        Unscale and clip gradients before performing the optimization step.
+
+        Args:
+            total_norm (float): The computed total gradient norm.
+
+        Returns:
+            None
+        """
         div_scale = 1.0
+
+        # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
         if self.mixed_precision is not None:
             div_scale = self.mixed_precision.get_grad_div_scale()

         if self.max_norm > 0.0:
             # norm is in fact norm*scale
+            # Calculate the scaling factor for gradient clipping
+            # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
+
+            # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
             if clip > 1:
                 div_scale = clip * div_scale

+        # Apply the scaling factor to gradients
         for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is None:
                     continue
                 p.grad.data.mul_(1.0 / div_scale)

-    def _compute_grad_norm(self) -> float:
-        if self.max_norm <= 0.0:
-            return 0.0
-        grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
-        if len(grads) == 0:
-            return 0.0
-        device = grads[0].device
-        # TODO(ver217): support tp
-        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
-        return total_norm.item()
+    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
+        r"""
+        Compute and return the gradient norm for gradient clipping.
+
+        Args:
+            param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
+            norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
+
+        Returns:
+            float: The total norm of the given gradients.
+        """
+
+        if len(param_gradient_pairs) == 0:
+            return 0.0
+
+        # gradients used for norm calculation.
+        gradients = [grad for param, grad in param_gradient_pairs]
+
+        if norm_type == inf:
+            total_norm = max(grad.data.abs().max() for grad in gradients)
+
+        else:
+            total_norm_exponentiated = 0.0
+            for grad in gradients:
+                total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
+            total_norm = total_norm_exponentiated ** (1.0 / norm_type)
+
+        return total_norm

     def step(self, *args, **kwargs):
         if self.mixed_precision.should_skip_step():
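For reference, the norm computation added above follows the usual split between the infinity norm and a general p-norm. A standalone replica in plain PyTorch (function and variable names are mine, not part of the diff) that can be checked by hand:

import torch
from torch import inf

def total_grad_norm(gradients, norm_type=2):
    # Mirrors the branch structure above: the infinity norm takes the largest
    # absolute entry, any other p-norm sums |g|^p and takes the p-th root.
    if len(gradients) == 0:
        return 0.0
    if norm_type == inf:
        return max(g.abs().max().item() for g in gradients)
    total = sum(g.double().norm(norm_type) ** norm_type for g in gradients)
    return float(total ** (1.0 / norm_type))

grads = [torch.tensor([3.0, 4.0]), torch.tensor([0.0])]
print(total_grad_norm(grads))        # 5.0  (global L2 norm over all grads)
print(total_grad_norm(grads, inf))   # 4.0  (largest absolute entry)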
@@ -142,8 +173,22 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if working_param.grad is not None:
                     p.grad = working_param.grad.data.float()
                     working_param.grad = None
-        total_norm = self._compute_grad_norm()
+
+        # gradient unscale and clip.
+        if self.max_norm <= 0:
+            # no need to compute gradient norm.
+            total_norm = 0.0
+        else:
+            # compute the total norm.
+            param_gradient_pairs = [
+                (self.master_to_working_map[p], p.grad)
+                for group in self.param_groups
+                for p in group["params"]
+                if p.grad is not None
+            ]
+            total_norm = self._compute_grad_norm(param_gradient_pairs)
         self._unscale_and_clip_grads(total_norm)
+
         self.optim.step(*args, **kwargs)
         # update working params
         for group in self.optim.param_groups:
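To make the "norm is in fact norm*scale" bookkeeping in step() concrete, here is a standalone illustration of the unscale-and-clip arithmetic (plain PyTorch, my own names; the real method operates on master-parameter grads and the mixed-precision loss scale):

import torch

def unscale_and_clip_(grads, total_norm, div_scale=1.0, max_norm=1.0):
    # total_norm is the norm of the *scaled* grads, so it is divided by
    # div_scale (the loss scale) before being compared against max_norm.
    if max_norm > 0.0:
        clip = ((total_norm / div_scale) + 1e-6) / max_norm
        if clip > 1:
            div_scale = clip * div_scale
    # a single in-place multiply both unscales and clips
    for g in grads:
        g.mul_(1.0 / div_scale)

grads = [torch.full((4,), 3.0)]              # global L2 norm = 6.0
unscale_and_clip_(grads, total_norm=6.0, max_norm=1.0)
print(torch.linalg.norm(grads[0]))           # ~1.0 after clipping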