mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
fixed zero level 3 dtype bug (#76)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from .apex_amp import ApexAMPOptimizer
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
import apex.amp as apex_amp
|
||||
|
||||
|
||||
def convert_to_apex_amp(model: nn.Module,
|
||||
@@ -19,6 +18,7 @@ def convert_to_apex_amp(model: nn.Module,
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
"""
|
||||
import apex.amp as apex_amp
|
||||
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
|
||||
optimizer = ApexAMPOptimizer(optimizer)
|
||||
return model, optimizer
|
||||
|
@@ -30,11 +30,7 @@ def convert_to_zero(model: nn.Module,
|
||||
:rtype: Tuple
|
||||
"""
|
||||
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
|
||||
if level == 2:
|
||||
if is_no_pp_or_last_stage():
|
||||
model = NaiveAMPModel(model, output_to_fp32=True)
|
||||
else:
|
||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
||||
model = NaiveAMPModel(model, output_to_fp32=False)
|
||||
|
||||
if level == 2:
|
||||
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
|
||||
|
@@ -695,13 +695,23 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
|
||||
},
|
||||
"aio": aio_config
|
||||
}
|
||||
remote_device = offload_param_config['device']
|
||||
|
||||
if offload_param_config is not None:
|
||||
remote_device = offload_param_config['device']
|
||||
else:
|
||||
remote_device = None
|
||||
|
||||
if offload_optimizer_config is not None:
|
||||
pin_memory = offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
|
||||
else:
|
||||
pin_memory = False
|
||||
|
||||
group = None
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
group = gpc.get_group(ParallelMode.DATA)
|
||||
Init(module=module, data_parallel_group=group, dtype=self.dtype,
|
||||
remote_device=remote_device, config_dict_or_path=ds_config,
|
||||
pin_memory=offload_optimizer_config[OFFLOAD_OPTIMIZER_PIN_MEMORY])
|
||||
pin_memory=pin_memory)
|
||||
|
||||
for m in module.modules():
|
||||
_init_external_params(m)
|
||||
|
Reference in New Issue
Block a user