[feature] Add clip_grad_norm for hybrid_parallel_plugin (#4837)

* Add clip_grad_norm for hybrid_parallel_plugin

* polish code

* add unittests

* Move tp to a higher-level optimizer interface.

* bug fix

* polish code
littsk
2023-10-12 11:32:37 +08:00
committed by GitHub
parent df63564184
commit 83b52c56cd
8 changed files with 1158 additions and 90 deletions
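For context, a minimal usage sketch of what this change enables: passing max_norm to HybridParallelPlugin so the wrapped mixed-precision optimizer clips gradients before stepping. The booster setup around it is illustrative and not taken from this diff.

# Illustrative usage; max_norm is the option exercised by this commit.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,        # tensor parallel degree
    pp_size=1,        # pipeline parallel degree
    precision="fp16",
    max_norm=1.0,     # > 0 enables gradient clipping in the mixed-precision optimizer
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)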


@@ -1,7 +1,7 @@
from typing import Dict, List
from typing import Dict, List, Tuple
import torch
from torch import Tensor
from torch import Tensor, inf
from torch.nn import Module, Parameter
from torch.optim import Optimizer
@@ -68,8 +68,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
            self.mixed_precision = BF16MixedPrecisionMixin()
        else:
            raise ValueError(f"Unsupported precision: {precision}")
        if max_norm > 0.0:
            raise NotImplementedError("max_norm is not supported yet.")
        self.max_norm = max_norm
        self.working_to_master_map: Dict[Parameter, Tensor] = {}
        self.master_to_working_map: Dict[Tensor, Parameter] = {}
@@ -102,32 +100,65 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
        return super().zero_grad(*args, **kwargs)

    def _unscale_and_clip_grads(self, total_norm: float) -> None:
        """
        Unscale and clip gradients before performing the optimization step.

        Args:
            total_norm (float): The computed total gradient norm.

        Returns:
            None
        """
        div_scale = 1.0
        # If mixed-precision training is used, get the gradient division scale from the mixed-precision handler.
        if self.mixed_precision is not None:
            div_scale = self.mixed_precision.get_grad_div_scale()

        if self.max_norm > 0.0:
            # norm is in fact norm*scale
            # Calculate the scaling factor for gradient clipping
            # The gradient norm is scaled by 'div_scale' and then clipped to 'max_norm'
            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
            # If the clip factor exceeds 1, adjust 'div_scale' accordingly to ensure clipping
            if clip > 1:
                div_scale = clip * div_scale

        # Apply the scaling factor to gradients
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.grad.data.mul_(1.0 / div_scale)
    def _compute_grad_norm(self) -> float:
        if self.max_norm <= 0.0:
    def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
Args:
param_gradient_pairs (List[Tuple[Tensor]]): List of (parameter, gradient) pairs; gradients are used for norm calculation.
norm_type (int, optional): Type of the norm used (e.g., 2 for L2 norm). Defaults to 2.
Returns:
float: The total norm of the given gradients.
"""
if len(param_gradient_pairs) == 0:
return 0.0
grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()
# gradients used for norm calculation.
gradients = [grad for param, grad in param_gradient_pairs]
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
else:
total_norm_exponentiated = 0.0
for grad in gradients:
total_norm_exponentiated += grad.data.double().norm(norm_type) ** norm_type
total_norm = total_norm_exponentiated ** (1.0 / norm_type)
return total_norm
def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
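To make the clipping rule above concrete, here is a small standalone sketch of the same arithmetic with made-up numbers (the loss scale and norms are illustrative, not taken from the diff):

# Sketch of the unscale-and-clip arithmetic in _unscale_and_clip_grads (illustrative values).
div_scale = 1024.0                   # loss scale already multiplied into the fp16 grads
true_norm = 5.0                      # unscaled L2 norm of the gradients
total_norm = true_norm * div_scale   # what the optimizer sees: "norm is in fact norm*scale"
max_norm = 1.0

clip = ((total_norm / div_scale) + 1e-6) / max_norm  # ~5.0, so clipping is required
if clip > 1:
    div_scale = clip * div_scale                     # ~5120
# every gradient is then multiplied by 1.0 / div_scale, which removes the loss scale
# and shrinks the gradient norm to roughly max_norm in a single pass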
@@ -142,8 +173,22 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                if working_param.grad is not None:
                    p.grad = working_param.grad.data.float()
                    working_param.grad = None
        total_norm = self._compute_grad_norm()
        # gradient unscale and clip.
        if self.max_norm <= 0:
            # no need to compute gradient norm.
            total_norm = 0.0
        else:
            # compute the total norm.
            param_gradient_pairs = [
                (self.master_to_working_map[p], p.grad)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
            total_norm = self._compute_grad_norm(param_gradient_pairs)

        self._unscale_and_clip_grads(total_norm)
        self.optim.step(*args, **kwargs)
        # update working params
        for group in self.optim.param_groups:
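As a quick sanity check on the norm computation introduced above, a standalone sketch that mirrors its two branches on hypothetical tensors (not part of the commit):

# Mirrors _compute_grad_norm's two branches on made-up gradients.
import torch

gradients = [torch.tensor([3.0, 4.0]), torch.tensor([0.0, 12.0])]

norm_type = 2
total_norm_exponentiated = sum(g.double().norm(norm_type) ** norm_type for g in gradients)
l2_norm = total_norm_exponentiated ** (1.0 / norm_type)   # sqrt(25 + 144) = 13.0

inf_norm = max(g.abs().max() for g in gradients)           # largest absolute entry: 12.0

# agrees with the pre-existing reference computation on the same tensors
ref = torch.norm(torch.stack([torch.norm(g, 2) for g in gradients]), 2)  # 13.0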