Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
fixed zero level 3 dtype bug (#76)
@@ -30,11 +30,7 @@ def convert_to_zero(model: nn.Module,
     :rtype: Tuple
     """
     assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
-    if level == 2:
-        if is_no_pp_or_last_stage():
-            model = NaiveAMPModel(model, output_to_fp32=True)
-        else:
-            model = NaiveAMPModel(model, output_to_fp32=False)
+    model = NaiveAMPModel(model, output_to_fp32=False)
 
     if level == 2:
         optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
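For context, a minimal usage sketch of convert_to_zero after this change. The call signature, the import path (colossalai.zero), the empty zero_config, and the (model, optimizer) return tuple are assumptions inferred from the hunk above and its ':rtype: Tuple' docstring, not confirmed API details; the real ColossalAI call may differ.

# Minimal sketch, assuming convert_to_zero(model, optimizer, level, zero_config)
# returns a (model, optimizer) tuple; import path and parameters are inferred.
import torch
import torch.nn as nn

from colossalai.zero import convert_to_zero  # assumed import path

model = nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# After this fix the model is always wrapped in NaiveAMPModel with
# output_to_fp32=False, for both ZeRO level 2 and level 3, which is what
# addresses the level-3 dtype bug named in the commit title.
model, optimizer = convert_to_zero(model,
                                   optimizer,
                                   level=3,
                                   zero_config={})  # placeholder config; real setups pass ZeRO options here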