[gemini] remove registered gradient hooks (#5696)
* fix gemini
* fix
@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module
 
     def check_local_overflow(self) -> bool:
-        return self.module.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter > 0
 
     def pre_zero_grad(self) -> None:
-        self.module.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter = 0
 
 
 class GeminiOptimizer(OptimizerWrapper):
@@ -202,7 +202,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 chunk16 = self.param_to_chunk16[fake_param]
                 begin, end = self.param_to_range[fake_param]
 
-                grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
+                grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
                 fake_param.data = grad_chunk16.payload[begin:end]
                 fake_param.grad = fake_param.data
 
@@ -221,14 +221,14 @@ class GeminiOptimizer(OptimizerWrapper):
 
     def _clear_global_norm(self) -> None:
         for c16 in self.chunk16_set:
-            grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
+            grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
             grad_chunk.l2_norm = None
 
     def _calc_global_norm(self) -> float:
         norm_sqr: float = 0.0
         group_to_norm = dict()
         for c16 in self.chunk16_set:
-            grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
+            grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
             assert grad_chunk.l2_norm is not None
 
             if grad_chunk.is_gathered:
@@ -275,7 +275,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.info(f"Found overflow. Skip step")
             self._clear_global_norm()  # clear recorded norm
             self.zero_grad()  # reset all gradients
-            if self.module.reuse_fp16_chunk:
+            if self.module.chunk_manager.reuse_fp16_chunk:
                 self._update_fp16_params()
             return
 
@@ -288,7 +288,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.zero_grad()
         if self.module.master_weights:
             self._update_fp16_params()
-        self.module.accumulating_grads = False
+        self.module.chunk_manager.accumulating_grads = False
         return ret
 
     def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
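The pattern the diff moves to is that overflow and gradient-accumulation state are read from the chunk manager rather than from the wrapped module itself. Below is a minimal, hypothetical sketch of that ownership layout; the ToyChunkManager, ToyGeminiModule, and ToyFP16Mixin classes are illustrative stand-ins, not ColossalAI's real ChunkManager, GeminiDDP, or mixin, and only the attribute names follow the diff above.

# Minimal sketch (assumed, simplified classes): shows the mixin reaching shared
# state through module.chunk_manager, as in the hunks above.

class ToyChunkManager:
    """Stand-in for the chunk manager that owns the shared counters."""
    def __init__(self) -> None:
        self.overflow_counter = 0        # bumped when a reduced grad chunk has inf/nan
        self.reuse_fp16_chunk = True     # whether fp16 chunks double as grad buffers
        self.accumulating_grads = False  # set while gradients are being accumulated


class ToyGeminiModule:
    """Stand-in for the wrapped model: state is reached via .chunk_manager."""
    def __init__(self) -> None:
        self.chunk_manager = ToyChunkManager()


class ToyFP16Mixin:
    """Mirrors the check_local_overflow / pre_zero_grad shape from the diff."""
    def __init__(self, module: ToyGeminiModule) -> None:
        self.module = module

    def check_local_overflow(self) -> bool:
        # The counter is read from the chunk manager, not the module.
        return self.module.chunk_manager.overflow_counter > 0

    def pre_zero_grad(self) -> None:
        # ...and reset there as well, before gradients are zeroed.
        self.module.chunk_manager.overflow_counter = 0


if __name__ == "__main__":
    module = ToyGeminiModule()
    mixin = ToyFP16Mixin(module)
    module.chunk_manager.overflow_counter = 3  # pretend a reduction saw inf/nan
    assert mixin.check_local_overflow()
    mixin.pre_zero_grad()
    assert not mixin.check_local_overflow()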