Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[gemini] support gradient accumulation (#4869)
* add test
* fix no_sync bug in low level zero plugin
* fix test
* add argument for grad accum
* add grad accum in backward hook for gemini
* finish implementation, rewrite tests
* fix test
* skip stuck model in low level zero test
* update doc
* optimize communication & fix gradient checkpoint
* modify doc
* cleaning codes
* update cpu adam fp16 case
@@ -59,6 +59,7 @@ class GeminiDDP(ModelWrapper):
         chunk_config_dict: Optional[dict] = None,
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
+        enable_gradient_accumulation: bool = False,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
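For context, the new enable_gradient_accumulation flag brings the usual gradient-accumulation pattern to Gemini's chunk-based backward path. The snippet below is a plain-PyTorch reference sketch of that pattern, not ColossalAI code; with GeminiDDP the same loop structure applies once the flag is set to True (model/optimizer wiring around GeminiDDP is intentionally omitted here).

    # Plain-PyTorch reference for gradient accumulation (illustrative only; the Gemini
    # equivalent is what enable_gradient_accumulation=True turns on inside GeminiDDP).
    import torch

    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_steps = 4

    for step in range(8):
        x = torch.randn(2, 16)                  # one micro-batch
        loss = model(x).sum() / accum_steps     # scale so the accumulated grad matches a full batch
        loss.backward()                         # grads add up in p.grad across calls
        if (step + 1) % accum_steps == 0:
            optimizer.step()                    # update once per accumulation window
            optimizer.zero_grad()               # reset the accumulated gradients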
@@ -119,6 +120,11 @@ class GeminiDDP(ModelWrapper):
         self.reuse_fp16_chunk = master_weights
         self.master_weights = master_weights

+        self.enable_gradient_accumulation = enable_gradient_accumulation
+        if self.enable_gradient_accumulation:
+            self.reuse_fp16_chunk = False
+        self.accumulating_grads = False  # Whether model is accumulating gradients
+
         self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
@@ -298,6 +304,8 @@ class GeminiDDP(ModelWrapper):
                 f"{error_str}",
             )
         self._setup_grads_ptr()
+        if self.enable_gradient_accumulation and not self.accumulating_grads:
+            self.accumulating_grads = True  # Turn on the state of gradient accumulation.
         self._logger.debug(
             f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
         )
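A toy sketch of the accumulating_grads state flip (plain Python, not Gemini internals; all names below are illustrative): the first backward pass writes gradients into the grad buffer, and every later pass adds to it until the optimizer clears the state.

    # Toy stand-in for the accumulation state machine; the copy_/add_tensor_to_chunk_slice
    # analogues are simplified to clone()/+= on a single buffer.
    import torch

    class GradSlot:
        def __init__(self):
            self.buf = None
            self.accumulating = False          # mirrors self.accumulating_grads

        def post_backward(self, grad: torch.Tensor):
            if not self.accumulating:
                self.buf = grad.clone()        # first pass: copy into the grad buffer
                self.accumulating = True       # flag turns on after the first backward
            else:
                self.buf += grad               # later passes: accumulate in place

    slot = GradSlot()
    for g in (torch.ones(3), 2 * torch.ones(3)):
        slot.post_backward(g)
    print(slot.buf)                            # tensor([3., 3., 3.])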
@@ -327,7 +335,15 @@ class GeminiDDP(ModelWrapper):
                 )
             grad_chunk = chunk
             if not self.reuse_fp16_chunk:
-                grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
+                if not self.accumulating_grads:
+                    grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
+                else:
+                    assert chunk.grad_chunk is not None
+                    if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
+                        grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
+                    else:
+                        grad_chunk = chunk.grad_chunk

                 # hold -> compute -> hold after bwd
                 grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
                 grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
@@ -336,7 +352,10 @@ class GeminiDDP(ModelWrapper):
                 chunk.tensor_trans_state(p, TensorState.HOLD)

             grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
-            grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
+            if not self.accumulating_grads:
+                grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
+            else:
+                grad_chunk.add_tensor_to_chunk_slice(p, grad)
             reduced = self.chunk_manager.reduce_chunk(grad_chunk)
             if reduced:
                 if not self.reuse_fp16_chunk:
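The copy-vs-add split above is what makes several micro-batch backward passes equivalent to one large-batch backward. A quick plain-PyTorch check of that equivalence (illustrative only, independent of Gemini):

    # Summing per-micro-batch gradients (the add_tensor_to_chunk_slice path) matches the
    # gradient of a single backward over the full batch.
    import torch

    w = torch.randn(4, requires_grad=True)
    x = torch.randn(8, 4)

    g_full, = torch.autograd.grad((x @ w).sum(), w)   # one full-batch backward

    g_acc = torch.zeros_like(w)
    for xb in x.chunk(2):                              # two micro-batches
        g_mb, = torch.autograd.grad((xb @ w).sum(), w)
        g_acc += g_mb                                  # accumulate by addition

    print(torch.allclose(g_full, g_acc))               # True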
@@ -354,7 +373,7 @@ class GeminiDDP(ModelWrapper):
                 if chunk.l2_norm_flag:
                     grad_chunk.set_l2_norm()
                 self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
-                if not self.master_weights:
+                if not (self.master_weights) or (self.enable_gradient_accumulation):
                     self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
         return empty_grad