[checkpointio] Sharded Optimizer Checkpoint for Gemini Plugin (#4302)

* sharded optimizer checkpoint for gemini plugin

* modify test to reduce testing time

* update doc

* fix bug when keep_gathered is true under GeminiPlugin

Author: Baizhou Zhang
Date: 2023-07-21 14:39:01 +08:00
Committed by: GitHub
Parent: fc5cef2c79
Commit: c6f6005990

12 changed files with 289 additions and 84 deletions
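
For orientation, below is a minimal sketch of the capability this commit adds: saving Gemini optimizer states as a sharded checkpoint through the Booster API. The toy model, paths, and constructor arguments are placeholders, and the exact keyword names should be checked against the checkpoint docs; treat this as an illustration rather than canonical usage.

```python
# Minimal sketch (assumptions: torchrun-style launch, default plugin arguments).
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})              # reads rank/world size from env

model = torch.nn.Linear(64, 8)                       # placeholder model
optimizer = HybridAdam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

booster = Booster(plugin=GeminiPlugin())
model, optimizer, criterion, *_ = booster.boost(model, optimizer, criterion)

# New in this change: Gemini optimizer states can be saved in a sharded way.
# With shard=True the checkpoint path is a directory; each shard is capped at
# roughly `size_per_shard` MB and an index file is written alongside the shards.
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, size_per_shard=1024)

# Loading restores the sharded states into a boosted optimizer.
booster.load_optimizer(optimizer, "ckpt/optimizer")
```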

@@ -21,10 +21,13 @@ Plugin is an important component that manages parallel configuration (eg: The ge
**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines.
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
**_LowLevelZeroPlugin:_** This plugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: Shards optimizer states across data parallel workers/GPUs. Stage 2: Shards optimizer states + gradients across data parallel workers/GPUs.
**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp.
### API of booster
{{ autodoc:colossalai.booster.Booster }}
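
Since the plugins listed above share the same Booster entry point, switching between them is mostly a matter of constructing a different plugin object. A brief sketch follows; constructor arguments (including the `stage` keyword) are assumptions kept at or near their defaults.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import (GeminiPlugin, LowLevelZeroPlugin,
                                       TorchDDPPlugin, TorchFSDPPlugin)

# Pick one of the interchangeable plugins; only this line changes.
plugin = LowLevelZeroPlugin(stage=2)   # ZeRO stage 2: shard optimizer states + gradients
# plugin = GeminiPlugin()              # ZeRO with chunk-based memory management
# plugin = TorchDDPPlugin()            # module-level data parallelism (PyTorch DDP)
# plugin = TorchFSDPPlugin()           # PyTorch FSDP

booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler)
```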

@@ -21,8 +21,6 @@ Model must be boosted by `colossalai.booster.Booster` before loading. It will de
## Optimizer Checkpoint
> ⚠ Saving optimizer checkpoint in a sharded way is not supported yet.
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
Optimizer must be boosted by `colossalai.booster.Booster` before saving.
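
As a usage sketch of the calls documented here (placeholder paths; `booster` and `optimizer` are assumed to come from an earlier `booster.boost(...)`, as required above):

```python
# The optimizer must already be boosted; see the setup sketch near the top.
# Unsharded: a single checkpoint file.
booster.save_optimizer(optimizer, "optimizer.pt")

# Sharded: a directory of shard files plus an index file (now also supported
# under GeminiPlugin by this change). `size_per_shard` is the shard cap in MB.
booster.save_optimizer(optimizer, "optimizer_ckpt/", shard=True, size_per_shard=1024)

# Loading likewise requires a boosted optimizer.
booster.load_optimizer(optimizer, "optimizer_ckpt/")
```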

@@ -51,8 +51,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
### Torch DDP Plugin
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
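
For comparison with the Gemini-focused examples above, here is a sketch of the same workflow under `TorchDDPPlugin` (placeholder model and paths; a real run needs a prior `colossalai.launch_from_torch` call to set up the process group):

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Placeholder model/optimizer; distributed launch is assumed to have happened already.
model = torch.nn.Linear(64, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Checkpointing goes through the same Booster methods shown above.
booster.save_optimizer(optimizer, "ddp_optimizer.pt")
booster.load_optimizer(optimizer, "ddp_optimizer.pt")
```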