[plugin] support get_grad_norm (#6115)

Author: Hongxin Liu
Date: 2024-11-05 18:12:47 +08:00
Committed by: GitHub
Parent: 13ffa08cfa
Commit: a15ab139ad

8 changed files with 40 additions and 2 deletions


@@ -293,6 +293,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_pg = pp_process_group
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        self._current_grad_norm: Optional[float] = None
         super().__init__(optim)
 
     def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
             (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
         ]
         total_norm = self._compute_grad_norm(param_gradient_pairs)
+        self._current_grad_norm = total_norm
 
         # Clip the gradients to prevent exploding gradients.
         self._clip_grad_norm(total_norm)
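The hunk above caches the total norm returned by `_compute_grad_norm` before clipping. As a rough illustration of what such a value represents, here is a minimal single-device sketch; the function name is hypothetical, and the real method additionally reduces the norm across the tensor/pipeline parallel process groups.

import torch

def naive_total_grad_norm(param_gradient_pairs, norm_type: float = 2.0) -> float:
    """Illustrative local total norm over (param, grad) pairs (no cross-rank reduction)."""
    if not param_gradient_pairs:
        return 0.0
    # Norm of each gradient, then the norm of those norms, gives the total p-norm.
    per_grad = torch.stack([grad.detach().norm(norm_type) for _, grad in param_gradient_pairs])
    return torch.linalg.vector_norm(per_grad, ord=norm_type).item()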
@@ -477,6 +479,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self):
         return None
 
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+
 
 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
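For reference, a minimal sketch of how a caller might read the cached norm after stepping. Only `get_grad_norm()` comes from this commit; `model`, `criterion`, the batch variables, and the surrounding loop are placeholders, and `optimizer` is assumed to be the plugin-wrapped optimizer (e.g. HybridParallelNaiveOptimizer) returned by the booster.

# Hypothetical training-step excerpt; all names except get_grad_norm() are assumptions.
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
optimizer.backward(loss)   # wrapper's backward, as shown in the first hunk above
optimizer.step()           # computes, caches, and clips the total grad norm

grad_norm = optimizer.get_grad_norm()  # None until the first step() has run
if grad_norm is not None:
    print(f"total grad norm: {grad_norm:.4f}")

Note that `norm_type` is accepted for interface compatibility, but the method simply returns the value cached during the most recent `step()`.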