[gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case
This commit is contained in:
Baizhou Zhang
2023-10-17 14:07:21 +08:00
committed by GitHub
parent a41cf88e9b
commit 21ba89cab6
11 changed files with 283 additions and 10 deletions

View File

@@ -59,6 +59,7 @@ class GeminiDDP(ModelWrapper):
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -119,6 +120,11 @@ class GeminiDDP(ModelWrapper):
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@@ -298,6 +304,8 @@ class GeminiDDP(ModelWrapper):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
@@ -327,7 +335,15 @@ class GeminiDDP(ModelWrapper):
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
@@ -336,7 +352,10 @@ class GeminiDDP(ModelWrapper):
chunk.tensor_trans_state(p, TensorState.HOLD)
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
@@ -354,7 +373,7 @@ class GeminiDDP(ModelWrapper):
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not self.master_weights:
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad