Fixed ZeRO level 3 dtype bug (#76)

This commit is contained in:
Frank Lee
2021-12-20 17:00:53 +08:00
committed by GitHub
parent 632e622de8
commit 91c327cb44
5 changed files with 16 additions and 12 deletions

View File

@@ -30,11 +30,7 @@ def convert_to_zero(model: nn.Module,
:rtype: Tuple
"""
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
if level == 2:
if is_no_pp_or_last_stage():
model = NaiveAMPModel(model, output_to_fp32=True)
else:
model = NaiveAMPModel(model, output_to_fp32=False)
model = NaiveAMPModel(model, output_to_fp32=False)
if level == 2:
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)

View File

@@ -695,13 +695,23 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
},
"aio": aio_config
}
remote_device = offload_param_config['device']
if offload_param_config is not None:
remote_device = offload_param_config['device']
else:
remote_device = None
if offload_optimizer_config is not None:
pin_memory = offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
else:
pin_memory = False
group = None
if gpc.is_initialized(ParallelMode.DATA):
group = gpc.get_group(ParallelMode.DATA)
Init(module=module, data_parallel_group=group, dtype=self.dtype,
remote_device=remote_device, config_dict_or_path=ds_config,
pin_memory=offload_optimizer_config[OFFLOAD_OPTIMIZER_PIN_MEMORY])
pin_memory=pin_memory)
for m in module.modules():
_init_external_params(m)