[plugin] support get_grad_norm (#6115)

Hongxin Liu
2024-11-05 18:12:47 +08:00
committed by GitHub
parent 13ffa08cfa
commit a15ab139ad
8 changed files with 40 additions and 2 deletions


@@ -218,6 +218,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             )
         elif self._dtype is torch.bfloat16:
            self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+        self._current_grad_norm: Optional[float] = None
 
     def __del__(self):
         for hook in self.grad_handles:
@@ -551,6 +552,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
 
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
+        self._current_grad_norm = global_norm
        self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
        # update the parameters
@@ -934,3 +936,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def _force_wait_all_gather(self):
         for param in self._working_param_to_padded_working_param.keys():
             wait_all_gather_handle(param)
+
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
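
For context, a minimal sketch of the pattern this diff adds: the global gradient norm computed during step() is cached in _current_grad_norm, and get_grad_norm() simply returns the cached value. The ToyCachedNormOptimizer class below is hypothetical and only illustrates that pattern; it is not the ColossalAI implementation.

from typing import Optional

import torch


class ToyCachedNormOptimizer:
    """Hypothetical optimizer illustrating the cached-norm pattern from this commit."""

    def __init__(self, params):
        self.params = list(params)
        # Mirrors the attribute added in the first hunk of the diff.
        self._current_grad_norm: Optional[float] = None

    def step(self):
        # Compute the global L2 norm over all gradients, then cache it,
        # mirroring `self._current_grad_norm = global_norm` in the second hunk.
        norms = [p.grad.norm(2) for p in self.params if p.grad is not None]
        global_norm = torch.norm(torch.stack(norms), 2).item() if norms else 0.0
        self._current_grad_norm = global_norm
        # ... unscale/clip gradients and apply the parameter update here ...

    def get_grad_norm(self, norm_type=2, **kwargs):
        # Same signature as the method added in the diff; returns the cached value.
        return self._current_grad_norm


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.randn(4))
    p.grad = torch.randn(4)
    opt = ToyCachedNormOptimizer([p])
    print(opt.get_grad_norm())  # None before any step()
    opt.step()
    print(opt.get_grad_norm())  # a float after step()

Caching the value during step() means callers (such as a plugin, per the commit title) can read the most recent global gradient norm afterwards without recomputing or re-communicating it.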