[gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpointing

* modify doc

* clean up code

* update cpu adam fp16 case
Author: Baizhou Zhang
Date: 2023-10-17 14:07:21 +08:00
Committed by: GitHub
Parent: a41cf88e9b
Commit: 21ba89cab6
11 changed files with 283 additions and 10 deletions
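As the commit message notes, Gemini implements accumulation inside a backward hook. A rough, framework-agnostic sketch of that idea is shown below (hypothetical buffer names; Gemini's actual implementation works on its managed chunks rather than on `.grad`):

```python
import torch
import torch.nn as nn

# Illustration only: accumulate per-parameter gradients into persistent fp32
# buffers from backward hooks (hypothetical buffers, not Gemini's chunk logic).
model = nn.Linear(16, 4)
accum_buffers = {name: torch.zeros_like(p, dtype=torch.float32)
                 for name, p in model.named_parameters()}

def make_hook(name):
    def hook(grad):
        accum_buffers[name] += grad.float()  # add this micro-batch's gradient
        return grad                          # hand the gradient back unchanged
    return hook

for name, p in model.named_parameters():
    p.register_hook(make_hook(name))

# After two micro-batches, the buffers hold the summed gradients.
for _ in range(2):
    model(torch.randn(8, 16)).sum().backward()
```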


@@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase):
 chunk_config_dict (dict, optional): chunk configuration dictionary.
 chunk_init_device (torch.device, optional): device to initialize the chunk.
 placement_policy (str, optional): "static" and "auto". Defaults to "static".
+enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradients will be stored after the backward pass. Defaults to False.
 shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
     If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
 offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
@@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase):
 warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
 steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
 precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
-master_weights (bool, optional): master weights. Defaults to True.
+master_weights (bool, optional): Whether to keep fp32 master parameter weights in the optimizer. Defaults to True.
 pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
 force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
 strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -291,6 +292,7 @@ class GeminiPlugin(DPPluginBase):
 chunk_config_dict: Optional[dict] = None,
 chunk_init_device: Optional[torch.device] = None,
 placement_policy: str = "static",
+enable_gradient_accumulation: bool = False,
 shard_param_frac: float = 1.0, # only for static placement
 offload_optim_frac: float = 0.0, # only for static placement
 offload_param_frac: float = 0.0, # only for static placement
@@ -323,6 +325,7 @@ class GeminiPlugin(DPPluginBase):
 chunk_config_dict=chunk_config_dict,
 chunk_init_device=(chunk_init_device or get_current_device()),
 placement_policy=placement_policy,
+enable_gradient_accumulation=enable_gradient_accumulation,
 shard_param_frac=shard_param_frac,
 offload_optim_frac=offload_optim_frac,
 offload_param_frac=offload_param_frac,
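Taken together, these hunks add an `enable_gradient_accumulation` flag that is plumbed from the plugin into the Gemini wrapper. A rough usage sketch under the standard Booster workflow follows; the model, optimizer, and hyperparameters are placeholders, and minor API details may differ between ColossalAI versions:

```python
# Sketch of gradient accumulation with GeminiPlugin; values marked "placeholder"
# are assumptions, and the exact API surface may vary slightly across versions.
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})  # run under torchrun; signature may vary by version

plugin = GeminiPlugin(placement_policy="static", enable_gradient_accumulation=True)
booster = Booster(plugin=plugin)

model = nn.Linear(32, 2)                              # placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

accumulation_steps = 4                                # placeholder
for step in range(16):
    inputs = torch.randn(8, 32, device="cuda")        # placeholder data
    labels = torch.randint(0, 2, (8,), device="cuda")
    loss = criterion(model(inputs), labels) / accumulation_steps
    booster.backward(loss, optimizer)                 # gradients accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```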


@@ -335,4 +335,4 @@ class LowLevelZeroPlugin(DPPluginBase):
 def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
     assert isinstance(optimizer, LowLevelZeroOptimizer)
-    return optimizer.optim.no_sync()
+    return optimizer.no_sync()
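This hunk makes `no_sync` return the `LowLevelZeroOptimizer` wrapper's own context manager instead of reaching into the inner `optim`. On the user side, gradient accumulation with this plugin still goes through `no_sync`; a hedged sketch assuming the `Booster.no_sync` entry point, with placeholder values and the same setup as the Gemini sketch above:

```python
# Sketch: gradient accumulation with LowLevelZeroPlugin via no_sync.
# Distributed launch, model, optimizer, criterion, and dataloader are assumed
# to be set up as in the Gemini sketch above.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=1)  # stage-1 config; no_sync support may depend on stage
booster = Booster(plugin=plugin)
model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

accumulation_steps = 4                # placeholder
for step, (inputs, labels) in enumerate(dataloader):
    loss = criterion(model(inputs), labels) / accumulation_steps
    if (step + 1) % accumulation_steps != 0:
        # Intermediate micro-batch: skip gradient synchronization.
        with booster.no_sync(model, optimizer):
            booster.backward(loss, optimizer)
    else:
        booster.backward(loss, optimizer)  # synchronize on the final micro-batch
        optimizer.step()
        optimizer.zero_grad()
```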