[doc] Update booster user documents. (#4669)
* update booster_api.md
* update booster_checkpoint.md
* update booster_plugins.md
* move transformers importing inside function
* fix Dict typing
* fix autodoc bug
* small fix
@@ -1,6 +1,6 @@
# Booster API

Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https://github.com/CjhHa1)
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)

**Prerequisite:**
@@ -9,32 +9,35 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https:/
**Example Code**

- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)

## Introduction
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, I will cover how `colossalai.booster` works and what we should take note of.
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more conveniently. Also, calling `colossalai.booster` is the standard procedure before you run your training loop. In the sections below, we will cover how `colossalai.booster` works and what we should take note of.
### Plugin
Plugin is an important component that manages parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). Currently supported plugins are as follows:

**_HybridParallelPlugin:_** This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallel, pipeline parallel and data parallel strategies, including DDP and ZeRO.

**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that is, ZeRO with chunk-based memory management.

**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallel at the module level which can run across multiple machines.

**_LowLevelZeroPlugin:_** This plugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: Shards optimizer states across data parallel workers/GPUs. Stage 2: Shards optimizer states + gradients across data parallel workers/GPUs.

**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp.
More details about the usage of each plugin can be found in the chapter [Booster Plugins](./booster_plugins.md).
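To make the plugin descriptions above concrete, here is a minimal selection sketch. The constructor arguments shown (`stage`, `tp_size`, `pp_size`) are taken from the plugin classes, but treat the exact signatures as version-dependent and check [Booster Plugins](./booster_plugins.md) for details.

```python
# A minimal sketch: pick one plugin and hand it to the Booster.
# Assumes the distributed environment has already been launched (e.g. via colossalai.launch).
from colossalai.booster import Booster
from colossalai.booster.plugin import (
    GeminiPlugin,
    HybridParallelPlugin,
    LowLevelZeroPlugin,
    TorchDDPPlugin,
    TorchFSDPPlugin,
)

plugin = TorchDDPPlugin()                              # plain data parallelism (DDP)
# plugin = GeminiPlugin()                              # ZeRO with chunk-based memory management
# plugin = LowLevelZeroPlugin(stage=1)                 # ZeRO stage 1: shard optimizer states
# plugin = LowLevelZeroPlugin(stage=2)                 # ZeRO stage 2: shard optimizer states + gradients
# plugin = TorchFSDPPlugin()                           # PyTorch FSDP (zero-dp)
# plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # tensor/pipeline/data parallel combination

booster = Booster(plugin=plugin)
# The training objects are then wrapped with booster.boost(...), as shown in the usage example below.
```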
### API of booster

{{ autodoc:colossalai.booster.Booster }}

## Usage

In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `colossalai.booster` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes.
In a typical workflow, you should launch the distributed environment at the beginning of the training script and first create the objects needed (such as models, optimizers, loss functions, data loaders, etc.), then call `booster.boost` to inject features into these objects. After that, you can use our booster APIs and these returned objects to continue the rest of your training process.

A pseudo-code example is shown below:
@@ -48,15 +51,21 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

def train():
    # launch colossalai
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create plugin and objects for training
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # use booster.boost to wrap the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # do training as normal, except that the backward should be called by booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
@@ -65,14 +74,16 @@ def train():
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # checkpointing using booster api
    save_path = "./model"
    booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    new_model = resnet18()
    booster.load_model(new_model, save_path)
```
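The hunks above skip the lines between `output = model(x)` and `optimizer.clip_grad_by_norm(1.0)`, where the loss is computed and the backward pass is run. A minimal sketch of that step (not the verbatim omitted lines): with the Booster API the backward pass goes through the booster rather than `loss.backward()`.

```python
# Sketch of the step elided between the two hunks above:
# compute the loss and let the booster drive the backward pass.
loss = criterion(output)
booster.backward(loss, optimizer)
```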
[more design details](https://github.com/hpcaitech/ColossalAI/discussions/3046)
For more design details please see [this page](https://github.com/hpcaitech/ColossalAI/discussions/3046).
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->
@@ -13,7 +13,7 @@ We've introduced the [Booster API](./booster_api.md) in the previous tutorial. I
{{ autodoc:colossalai.booster.Booster.save_model }}
Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).
Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to the saved checkpoint. It can be a file if `shard=False`; otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way, which is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers), so you can use the huggingface `from_pretrained` method to load a model from our sharded checkpoint.
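A minimal sketch of sharded saving and loading with a boosted model is given below; `size_per_shard` is assumed here to be measured in MB (with 1024 as the default), and the exact signature should be checked against the autodoc output.

```python
# Save a boosted model as a sharded checkpoint: `checkpoint` is a directory when shard=True.
booster.save_model(model, "./checkpoint_dir", shard=True, size_per_shard=1024, use_safetensors=True)

# Load the checkpoint back into a model with the same architecture.
booster.load_model(new_model, "./checkpoint_dir")
```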
{{ autodoc:colossalai.booster.Booster.load_model }}
@@ -1,6 +1,6 @@
# Booster Plugins

Author: [Hongxin Liu](https://github.com/ver217)
Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)

**Prerequisite:**
- [Booster API](./booster_api.md)
@@ -15,6 +15,7 @@ We currently provide the following plugins:
- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management.
- [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
- [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp.
- [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixed precision training, TorchDDP and Zero stage 1/2 features. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below.

More plugins are coming soon.
@@ -43,8 +44,6 @@ We've tested compatibility on some famous models, following models may not be su
Compatibility problems will be fixed in the future.

> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
### Gemini Plugin
This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It does not support local gradient accumulation at present. More details can be found in [Gemini Doc](../features/zero_with_chunk.md).
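A minimal usage sketch, assuming default `GeminiPlugin` arguments and the `HybridAdam` optimizer that is commonly paired with Gemini (see the Gemini doc linked above for the full argument list):

```python
from torchvision.models import resnet18
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumes the distributed environment has already been launched (e.g. via colossalai.launch).
plugin = GeminiPlugin()           # ZeRO-3 with chunk-based, heterogeneous memory management
booster = Booster(plugin=plugin)

model = resnet18()
criterion = lambda x: x.mean()
optimizer = HybridAdam(model.parameters(), lr=1e-3)

# boost() returns the wrapped model/optimizer/criterion (plus dataloader and lr_scheduler slots).
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```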
@@ -69,4 +68,24 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.h
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
### Hybrid Parallel Plugin
This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts:
1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel settings. Shardformer also overloads the logic of the model's forward/backward process to ensure the smooth working of tp/pp. In addition, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallelism are injected into the overloaded forward/backward method by Shardformer.

2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its argument configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md).

3. Torch DDP: This plugin will automatically adopt Pytorch DDP as the data parallel strategy when pipeline parallelism and Zero are not used. More details about its argument configuration can be found in [Pytorch DDP Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).

4. Zero: This plugin can adopt Zero 1/2 as the data parallel strategy by setting the `zero_stage` argument to 1 or 2 when initializing the plugin. Zero 1 is compatible with the pipeline parallel strategy, while Zero 2 is not. More details about its argument configuration can be found in [Low Level Zero Plugin](#low-level-zero-plugin). A minimal configuration sketch covering these four parts is given after this list.
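Below is a hedged configuration sketch mapping the four parts above onto plugin arguments. The argument names (`tp_size`, `pp_size`, `num_microbatches`, `precision`, `zero_stage` and the `enable_*` switches) follow the plugin constructor, but the exact signature should be checked against the autodoc output at the end of this section.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                         # 1. Shardformer: tensor parallel degree
    pp_size=2,                         #    pipeline parallel degree
    num_microbatches=4,                #    micro-batches used by the pipeline schedule
    enable_fused_normalization=True,   #    optional Shardformer optimizations
    enable_flash_attention=True,
    precision="fp16",                  # 2. mixed precision training (fp16/bf16)
    zero_stage=1,                      # 4. Zero stage 1 (compatible with pipeline parallelism)
)
# 3. When pipeline parallelism and Zero are not used, the plugin falls back to plain Torch DDP
#    for data parallelism.
booster = Booster(plugin=plugin)
```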
> ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer is compatible with tensor parallel, pipeline parallel and optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 etc. are all supported by Shardformer.

> ⚠ This plugin only supports sharded checkpointing methods for model/optimizer at present. Unsharded checkpointing methods will be supported in a future release.
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
<!-- doc-test-command: echo -->