diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 01f508724..4e3012010 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -422,29 +422,29 @@ class GeminiDDP(ModelWrapper):
             grad_chunk.add_tensor_to_chunk_slice(p, grad)
 
         with torch.cuda.stream(async_reduce_stream):
-            chunk_manager.reduce_chunk(grad_chunk)
-
-            if not chunk_manager.reuse_fp16_chunk:
-                if chunk.keep_gathered:
-                    chunk_manager.fake_release_chunk(chunk)
+            reduced = chunk_manager.reduce_chunk(grad_chunk)
+            if reduced:
+                if not chunk_manager.reuse_fp16_chunk:
+                    if chunk.keep_gathered:
+                        chunk_manager.fake_release_chunk(chunk)
+                    else:
+                        chunk_manager.release_chunk(chunk)
+                if grad_chunk.is_gathered:
+                    grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                    if chunk.extra_dp_group is not None:
+                        grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
                 else:
-                    chunk_manager.release_chunk(chunk)
-            if grad_chunk.is_gathered:
-                grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-            else:
-                grad_chunk.cuda_shard.div_(chunk.pg_size)
-                if chunk.extra_dp_group is not None:
-                    grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-            # check overflow elements
-            chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-            # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-            if chunk.l2_norm_flag:
-                grad_chunk.set_l2_norm()
-            chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-            if not (master_weights) or (enable_gradient_accumulation):
-                chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+                    grad_chunk.cuda_shard.div_(chunk.pg_size)
+                    if chunk.extra_dp_group is not None:
+                        grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+                # check overflow elements
+                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+                # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+                if chunk.l2_norm_flag:
+                    grad_chunk.set_l2_norm()
+                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                if not (master_weights) or (enable_gradient_accumulation):
+                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)