fix zero3 fp16 and add zero3 model context (#62)

This commit is contained in:
ver217
2021-12-10 17:48:50 +08:00
committed by GitHub
parent 9a0466534c
commit 7d3711058f
5 changed files with 114 additions and 11 deletions


@@ -83,4 +83,13 @@ Note that `fp16` is automatically enabled when using ZeRO. This relies on `AMP_T
### Training
Note that if your model is too large to fit within the memory when using ZeRO-3, you should use `colossalai.zero.zero3_model_context` to construct your model:
```python
from colossalai.zero import zero3_model_context
with zero3_model_context():
    model = Model()
```
Once you have completed your configuration, just use `colossalai.initialize()` to initialize your training.
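The reason construction happens inside the `with` block can be sketched in plain Python: while the context is active, the framework can observe (and, in real ZeRO-3, shard) parameters as they are allocated. The sketch below is a hypothetical, minimal illustration of that pattern only; `model_context`, `Param`, and the `_intercepting` flag are made-up names, not ColossalAI's actual implementation.

```python
from contextlib import contextmanager

# Hypothetical flag standing in for a framework hook that is active
# only while the context manager is open.
_intercepting = False

@contextmanager
def model_context():
    """Sketch: building the model inside the `with` block lets the
    framework intercept parameter allocation as it happens."""
    global _intercepting
    _intercepting = True
    try:
        yield
    finally:
        _intercepting = False

class Param:
    def __init__(self, n):
        # A real ZeRO-3 context would shard this allocation across ranks;
        # here we only record whether allocation was intercepted.
        self.sharded = _intercepting
        self.data = [0.0] * n

class Model:
    def __init__(self):
        self.weight = Param(4)

with model_context():
    model = Model()  # weight.sharded is True: allocation was seen by the context
```

Constructing the model outside the `with` block would leave `weight.sharded` as `False`, which is why the documentation insists construction happens inside the context.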


@@ -23,7 +23,7 @@ The ZeRO optimizer can partition three kinds of model states: optimizer states, gradients, and parameters
)
zero = dict(
type='ZeroRedundancyOptimizer_Level_3',
level=3,
dynamic_loss_scale=True,
clip_grad=1.0
)
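For context, the `zero` config section above can be annotated field by field. The field names come from the source; the comments are interpretive (dynamic loss scaling and gradient-norm clipping are standard fp16-training techniques), and the diff's add/remove markers were lost in extraction, so which line this hunk changed is not recoverable here.

```python
# The `zero` config section shown above, with interpretive comments.
zero = dict(
    type='ZeroRedundancyOptimizer_Level_3',  # ZeRO stage 3 optimizer
    level=3,                  # optimization level (stage 3)
    dynamic_loss_scale=True,  # adjust the fp16 loss scale at runtime to avoid overflow
    clip_grad=1.0,            # clip the gradient norm to 1.0 for stability
)
```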
@@ -78,4 +78,13 @@ The ZeRO optimizer can partition three kinds of model states: optimizer states, gradients, and parameters
### Training with the ZeRO Optimizer
Note that if your model is too large to fit in memory when using ZeRO-3, you should use `colossalai.zero.zero3_model_context` to construct your model:
```python
from colossalai.zero import zero3_model_context
with zero3_model_context():
    model = Model()
```
Once you have completed the configuration above, you can run `colossalai.initialize()` to start training.