[gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case
This commit is contained in:
Baizhou Zhang
2023-10-17 14:07:21 +08:00
committed by GitHub
parent a41cf88e9b
commit 21ba89cab6
11 changed files with 283 additions and 10 deletions

View File

@@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" and "auto". Defaults to "static".
enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False.
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
@@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase):
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
master_weights (bool, optional): master weights. Defaults to True.
master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -291,6 +292,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -323,6 +325,7 @@ class GeminiPlugin(DPPluginBase):
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
enable_gradient_accumulation=enable_gradient_accumulation,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,

View File

@@ -335,4 +335,4 @@ class LowLevelZeroPlugin(DPPluginBase):
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.optim.no_sync()
return optimizer.no_sync()

View File

@@ -434,6 +434,21 @@ class Chunk:
if update_ptr:
tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape)
def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
"""
Add data slice to the memory space indexed by the input tensor in the chunk.
Only used when accumulating gradient chunks.
Args:
tensor (torch.Tensor): the tensor used to retrieve meta information
data_slice (torch.Tensor): the tensor to be added to the chunk
"""
# sanity check
assert self.is_gathered
tensor_info = self.tensors_info[tensor]
self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten())
def get_valid_length(self) -> int:
"""Get the valid length of the chunk's payload."""
if self.keep_gathered:

View File

@@ -5,7 +5,7 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.utils import free_storage, get_current_device
from .chunk import Chunk, ChunkFullError, TensorState
@@ -255,3 +255,37 @@ class ChunkManager:
self.accessed_chunks.add(grad_chunk)
self.accessed_mem += grad_chunk.chunk_mem
return grad_chunk
def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
"""Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
assert chunk.grad_chunk is not None
# Make a backup for gradient accumulated before.
# Here backup gradients should be multiplied, since it will be divided after gradient reduction.
if chunk.grad_chunk.is_gathered:
accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size)
accumulated_grad_gathered = True
else:
if chunk.grad_chunk.cuda_shard is not None:
accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
else:
accumulated_grad = (
chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
)
accumulated_grad_gathered = False
# Reset grad_chunk, and chunk.grad_chunk will be accessed.
grad_chunk = self.init_grad_chunk(chunk)
grad_chunk.cuda_global_chunk.zero_()
# Add backup gradients to grad_chunk.
if accumulated_grad_gathered:
grad_chunk.cuda_global_chunk.add_(accumulated_grad)
else:
grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad)
# Release accumulated_grad
free_storage(accumulated_grad)
return grad_chunk

View File

@@ -59,6 +59,7 @@ class GeminiDDP(ModelWrapper):
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -119,6 +120,11 @@ class GeminiDDP(ModelWrapper):
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
@@ -298,6 +304,8 @@ class GeminiDDP(ModelWrapper):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
@@ -327,7 +335,15 @@ class GeminiDDP(ModelWrapper):
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD)
@@ -336,7 +352,10 @@ class GeminiDDP(ModelWrapper):
chunk.tensor_trans_state(p, TensorState.HOLD)
grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
@@ -354,7 +373,7 @@ class GeminiDDP(ModelWrapper):
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not self.master_weights:
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad

View File

@@ -263,6 +263,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.zero_grad()
if self.module.master_weights:
self._update_fp16_params()
self.module.accumulating_grads = False
return ret
def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):