fix zero3 fp16 and add zero3 model context (#62)

This commit is contained in:
ver217
2021-12-10 17:48:50 +08:00
committed by GitHub
parent 9a0466534c
commit 7d3711058f
5 changed files with 114 additions and 11 deletions

View File

@@ -220,7 +220,9 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# first sync model across dp ranks
model.to(get_current_device())
sync_model_param_in_dp(model)
use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
if not use_zero3:
sync_model_param_in_dp(model)
# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)