[gemini] remove registered gradient hooks (#5696)

* fix gemini
* fix
Author: flybird11111
Date: 2024-05-09 10:29:49 +08:00 (committed by GitHub)
Commit: d4c5ef441e (parent: 22297789ab)
5 changed files with 93 additions and 46 deletions


@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module
 
     def check_local_overflow(self) -> bool:
-        return self.module.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter > 0
 
     def pre_zero_grad(self) -> None:
-        self.module.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter = 0
 
 
 class GeminiOptimizer(OptimizerWrapper):
@@ -202,7 +202,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 chunk16 = self.param_to_chunk16[fake_param]
                 begin, end = self.param_to_range[fake_param]
 
-                grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
+                grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
                 fake_param.data = grad_chunk16.payload[begin:end]
                 fake_param.grad = fake_param.data
 
@@ -221,14 +221,14 @@ class GeminiOptimizer(OptimizerWrapper):
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
grad_chunk.l2_norm = None
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
assert grad_chunk.l2_norm is not None
if grad_chunk.is_gathered:
@@ -275,7 +275,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.info(f"Found overflow. Skip step")
             self._clear_global_norm()  # clear recorded norm
             self.zero_grad()  # reset all gradients
-            if self.module.reuse_fp16_chunk:
+            if self.module.chunk_manager.reuse_fp16_chunk:
                 self._update_fp16_params()
             return
@@ -288,7 +288,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.zero_grad()
         if self.module.master_weights:
             self._update_fp16_params()
-        self.module.accumulating_grads = False
+        self.module.chunk_manager.accumulating_grads = False
         return ret
 
     def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
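
Net effect of the hunks in this file: the optimizer no longer reads overflow/accumulation state off the wrapped module itself (self.module.overflow_counter, self.module.reuse_fp16_chunk, self.module.accumulating_grads) but off the module's chunk manager. Presumably this is what lets the registered gradient hooks mentioned in the title go away; the hook removal itself lives in the other changed files not shown here. Below is a minimal sketch of the new access pattern only; the stand-in ChunkManager / GeminiDDP shells carry just the attributes visible in this diff, and everything else about those classes is an assumption.

# Sketch only: stand-in classes with the state touched by this diff.
class ChunkManager:
    def __init__(self) -> None:
        self.overflow_counter = 0        # bumped when a reduced grad chunk overflows
        self.reuse_fp16_chunk = True     # whether fp16 chunks double as grad chunks
        self.accumulating_grads = False  # True while gradients are being accumulated

class GeminiDDP:
    def __init__(self) -> None:
        self.chunk_manager = ChunkManager()

def check_local_overflow(module: GeminiDDP) -> bool:
    # Old: module.overflow_counter > 0
    # New: the counter lives on the chunk manager.
    return module.chunk_manager.overflow_counter > 0

def pre_zero_grad(module: GeminiDDP) -> None:
    module.chunk_manager.overflow_counter = 0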