fixed zero level 3 dtype bug (#76)

This commit is contained in:
Frank Lee
2021-12-20 17:00:53 +08:00
committed by GitHub
parent 632e622de8
commit 91c327cb44
5 changed files with 16 additions and 12 deletions

View File

@@ -30,11 +30,7 @@ def convert_to_zero(model: nn.Module,
:rtype: Tuple
"""
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
if level == 2:
if is_no_pp_or_last_stage():
model = NaiveAMPModel(model, output_to_fp32=True)
else:
model = NaiveAMPModel(model, output_to_fp32=False)
model = NaiveAMPModel(model, output_to_fp32=False)
if level == 2:
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)