[plugin] support get_grad_norm (#6115)

Author: Hongxin Liu
Date: 2024-11-05 18:12:47 +08:00 (committed by GitHub)
Parent: 13ffa08cfa
Commit: a15ab139ad

8 changed files with 40 additions and 2 deletions


@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, inf
@@ -84,6 +84,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                     self.master_to_working_map[master_p] = p
                 master_params.append(master_p)
             group["params"] = master_params
+        self._current_grad_norm: Optional[float] = None
 
     def backward(self, loss: Tensor, *args, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+        self._current_grad_norm = total_norm
         self._unscale_and_clip_grads(total_norm)
         self.optim.step(*args, **kwargs)
 
@@ -212,3 +214,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
        return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
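
Usage sketch (not part of the commit): the fragment below assumes `model`, `criterion`, and `dataloader` are placeholders, and that `optimizer` has already been wrapped by a mixed-precision plugin so it is the patched `MixedPrecisionOptimizer` shown above (e.g. after boosting with a ColossalAI plugin).

# Hypothetical training-loop fragment, not taken from this commit.
# `model`, `criterion`, `dataloader` are placeholders; `optimizer` is assumed to be
# the plugin-wrapped MixedPrecisionOptimizer from the diff above.
for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    optimizer.backward(loss)   # wrapper API from the diff: scales the loss, then runs backward
    optimizer.step()           # step() now caches the gradient norm it computed
    grad_norm = optimizer.get_grad_norm()  # value cached by the last step(), or None before any step
    if grad_norm is not None:
        print(f"grad norm: {grad_norm:.4f}")

Note that, as implemented in the diff, `get_grad_norm` ignores its `norm_type` argument and simply returns the value recorded during the most recent `step()`; it is `None` until the first step has run.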