Mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] support get_grad_norm (#6115)
@@ -218,6 +218,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             )
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+        self._current_grad_norm: Optional[float] = None

     def __del__(self):
         for hook in self.grad_handles:
@@ -551,6 +552,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
+        self._current_grad_norm = global_norm
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)

         # update the parameters
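The value cached here is the global norm that the optimizer already derives from the per-group norms before unscaling and clipping. As a rough illustration only (an assumption about what a helper like calculate_global_norm_from_list amounts to, not its actual source), combining per-group L2 norms into one global L2 norm looks like this:

import math

def global_l2_norm_from_group_norms(norm_list):
    # Illustrative sketch: combine per-parameter-group L2 norms into a single
    # global L2 norm, i.e. sqrt(n1^2 + n2^2 + ...). This is the kind of value
    # that ends up cached in self._current_grad_norm.
    return math.sqrt(sum(n ** 2 for n in norm_list))

print(global_l2_norm_from_group_norms([3.0, 4.0]))  # 5.0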
@@ -934,3 +936,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _force_wait_all_gather(self):
         for param in self._working_param_to_padded_working_param.keys():
             wait_all_gather_handle(param)
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
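The new getter exposes the norm that step() already computes, so callers do not have to recompute it. A hypothetical training-loop excerpt, assuming `optimizer` is a LowLevelZeroOptimizer (e.g. obtained via the low-level ZeRO plugin) and `loss` is the current batch loss; this is a sketch, not the library's documented usage:

# Sketch: `optimizer` and `loss` are assumed to exist from the surrounding
# training loop; LowLevelZeroOptimizer provides backward() and step().
optimizer.backward(loss)   # reduce/accumulate gradients across ranks
optimizer.step()           # computes the global grad norm, clips, updates params

# New in this commit: the norm computed during the last step() is cached on the
# optimizer and can be read back. Before the first step() it is None, and the
# norm_type / extra kwargs are currently ignored (the cached value is returned).
grad_norm = optimizer.get_grad_norm(norm_type=2)
print(f"global grad norm of last step: {grad_norm}")

The cached value is the same global_norm that is passed to _unscale_and_clip_grads in the second hunk, captured before the parameters are updated.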