Mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] support get_grad_norm (#6115)
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from torch import Tensor, inf
@@ -84,6 +84,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                     self.master_to_working_map[master_p] = p
                     master_params.append(master_p)
             group["params"] = master_params
+        self._current_grad_norm: Optional[float] = None
 
     def backward(self, loss: Tensor, *args, **kwargs):
         loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
                 if p.grad is not None
             ]
             total_norm = self._compute_grad_norm(param_gradient_pairs)
+            self._current_grad_norm = total_norm
         self._unscale_and_clip_grads(total_norm)
 
         self.optim.step(*args, **kwargs)
@@ -212,3 +214,6 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
 
     def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
         return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
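For context, a minimal usage sketch of the new accessor (the names train_step, model, optim, batch, and loss_fn are illustrative placeholders, not part of this commit): step() computes the gradient norm and caches it in _current_grad_norm, and get_grad_norm() simply returns that cached value, so it yields None until a step has actually computed a norm.

def train_step(model, optim, batch, loss_fn):
    # `optim` is assumed to be a MixedPrecisionOptimizer (or a wrapper exposing the
    # same interface); the surrounding training code is a placeholder sketch.
    optim.zero_grad()
    loss = loss_fn(model(batch["x"]), batch["y"])
    optim.backward(loss)   # runs mixed_precision.pre_backward() before backprop
    optim.step()           # computes the grad norm and caches it in _current_grad_norm
    grad_norm = optim.get_grad_norm()  # cached value from this step, or None if no norm was computed
    return loss.item(), grad_norm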