[plugin] support get_grad_norm (#6115)

Author: Hongxin Liu
Date: 2024-11-05 18:12:47 +08:00
Committed by: GitHub
Parent: 13ffa08cfa
Commit: a15ab139ad

8 changed files with 40 additions and 2 deletions
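
With this change, the Gemini optimizer caches the global gradient norm it computes for clipping, and `get_grad_norm()` exposes that value to user code. Below is a minimal usage sketch, assuming ColossalAI's standard Booster/GeminiPlugin workflow (launched under torchrun); the toy model, input, and hyperparameters are placeholders and not part of this commit.

```python
# Hedged usage sketch: assumes ColossalAI's Booster/GeminiPlugin API and a
# torchrun launch; the toy model and data are placeholders.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()

# max_norm > 0 enables Gemini's internal gradient clipping, which is what
# computes (and, after this commit, caches) the global norm.
plugin = GeminiPlugin(max_norm=1.0)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

loss = model(torch.randn(8, 16).cuda()).sum()
optimizer.backward(loss)  # Booster-wrapped optimizers own the backward pass
optimizer.step()

print(optimizer.get_grad_norm())  # the norm cached during the clipped step
```

The relevant hunks from the Gemini optimizer follow.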

```diff
@@ -1,7 +1,7 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
 import math
-from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -195,6 +195,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
 
         self._register_states = disposable(self._register_states_)
+        self._current_grad_norm: Optional[float] = None
 
     def _set_grad_ptr(self):
         for group in self.param_groups:
```
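
Because the cache starts as `None`, and (per the next hunk) is only populated when clipping is enabled, callers should treat `None` as "no norm available yet" rather than a numeric result. A hedged example:

```python
# Sketch: guard against the None sentinel (no clipped step has run yet,
# or clipping is disabled).
norm = optimizer.get_grad_norm()
if norm is not None:
    print(f"grad norm: {norm:.4f}")
```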
```diff
@@ -255,6 +256,7 @@ class GeminiOptimizer(OptimizerWrapper):
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
+            self._current_grad_norm = total_norm
             clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
                 div_scale = clip * div_scale
```
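
The cached value is the same `total_norm` that drives the clip coefficient. Restated as a standalone sketch, where the literal value stands in for `self._calc_global_norm()` (which this diff does not show):

```python
# Standalone restatement of the clipping arithmetic in the hunk above.
total_norm = 5.0  # placeholder for self._calc_global_norm()
max_norm = 1.0    # clipping threshold
div_scale = 1.0   # divisor already applied to gradients (e.g. loss scale)

clip = ((total_norm / div_scale) + 1e-6) / max_norm
if clip > 1:
    # Fold the excess into div_scale so gradients end up divided by
    # roughly total_norm / max_norm.
    div_scale = clip * div_scale

print(div_scale)  # ~5.0: gradients are scaled down to meet max_norm
```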
```diff
@@ -846,6 +848,9 @@ class GeminiOptimizer(OptimizerWrapper):
                 f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
             )
 
+    def get_grad_norm(self, norm_type=2, **kwargs):
+        return self._current_grad_norm
+
 
 class GeminiAdamOptimizer(GeminiOptimizer):
     def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
```
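
The getter is a plain read of the cache: nothing is recomputed, and `norm_type` is accepted only for interface compatibility (the cached value is whatever norm `_calc_global_norm()` produced, presumably the 2-norm). Here is the pattern in isolation, using a hypothetical class rather than the real optimizer:

```python
from typing import List, Optional

# Minimal sketch of the norm-caching pattern (hypothetical class and method
# names, not the actual GeminiOptimizer).
class NormCachingOptimizer:
    def __init__(self, max_norm: float = 0.0):
        self.max_norm = max_norm            # 0.0 means clipping disabled
        self._current_grad_norm: Optional[float] = None

    def _unscale_and_clip_grads(self, grads: List[float]) -> None:
        # Compute and cache the global 2-norm only on the clipping path,
        # mirroring the Gemini hunks above.
        if self.max_norm > 0:
            total_norm = sum(g * g for g in grads) ** 0.5
            self._current_grad_norm = total_norm

    def get_grad_norm(self, norm_type=2, **kwargs):
        # Pure read of the cache; None until a clipped step has run.
        return self._current_grad_norm

opt = NormCachingOptimizer(max_norm=1.0)
opt._unscale_and_clip_grads([3.0, 4.0])
print(opt.get_grad_norm())  # 5.0
```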