[plugin] support get_grad_norm (#6115)
@@ -293,6 +293,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         self.pp_pg = pp_process_group
         self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
         self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
+        self._current_grad_norm: Optional[float] = None
         super().__init__(optim)

     def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 (p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm

             # Clip the gradients to prevent exploding gradients.
             self._clip_grad_norm(total_norm)
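The cached value is whatever self._compute_grad_norm returns for the (param, grad) pairs above. As orientation for readers unfamiliar with hybrid parallelism, the sketch below shows one way such a norm could be aggregated across the tensor- and pipeline-parallel process groups (tp_pg, pp_pg) set up in the first hunk. It is a minimal illustration built on torch.distributed, not ColossalAI's actual _compute_grad_norm; the helper name and the de-duplication caveat are assumptions.

# Hedged sketch: NOT ColossalAI's _compute_grad_norm. Illustrates aggregating a
# p-norm over (param, grad) pairs across tensor- and pipeline-parallel groups,
# assuming torch.distributed is initialized and tp_pg/pp_pg play the roles
# suggested by the diff above.
from typing import Iterable, Optional, Tuple

import torch
import torch.distributed as dist


def compute_grad_norm_sketch(
    param_gradient_pairs: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    tp_pg: Optional[dist.ProcessGroup] = None,
    pp_pg: Optional[dist.ProcessGroup] = None,
    norm_type: float = 2.0,
) -> float:
    # Sum of local |grad|^norm_type over the parameters this rank holds.
    local = sum(grad.norm(norm_type).item() ** norm_type for _, grad in param_gradient_pairs)
    # Place the partial sum on a tensor; use a device matching the backend (e.g. CUDA for NCCL).
    total = torch.tensor(local, dtype=torch.float32)
    # Accumulate partial sums across the parallel groups so every rank sees the global
    # value. A real implementation must also avoid double-counting parameters that are
    # replicated (not sharded) across ranks.
    if tp_pg is not None and dist.get_world_size(tp_pg) > 1:
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=tp_pg)
    if pp_pg is not None and dist.get_world_size(pp_pg) > 1:
        dist.all_reduce(total, op=dist.ReduceOp.SUM, group=pp_pg)
    return total.item() ** (1.0 / norm_type)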
@@ -477,6 +479,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
     def get_master_to_working_map(self):
         return None

+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+

 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
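Taken together, the three hunks cache the gradient norm computed during step() and expose it through the new get_grad_norm() method, which returns None until a norm has been computed. A minimal usage sketch follows; only optimizer.get_grad_norm(norm_type=2) comes from this commit, while the booster/model/criterion scaffolding is an assumed, illustrative hybrid-parallel training setup.

# Hedged usage sketch: only get_grad_norm(norm_type=2) is introduced by this commit.
# model, optimizer, booster, criterion and batch are illustrative placeholders for a
# typical ColossalAI hybrid-parallel training step.
def train_step(model, optimizer, booster, criterion, batch):
    outputs = model(**batch)
    loss = criterion(outputs, batch["labels"])
    booster.backward(loss, optimizer)
    optimizer.step()
    # New in this commit: read back the norm cached while step() computed/clipped gradients.
    # May be None if no gradient norm has been computed yet.
    grad_norm = optimizer.get_grad_norm(norm_type=2)
    optimizer.zero_grad()
    return loss, grad_norm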