fixed zero level 3 dtype bug (#76)

Repository: https://github.com/hpcaitech/ColossalAI.git
Commit: 91c327cb44
Parent: 632e622de8
@@ -1,7 +1,6 @@
 from .apex_amp import ApexAMPOptimizer
 import torch.nn as nn
 from torch.optim import Optimizer
-import apex.amp as apex_amp
 
 
 def convert_to_apex_amp(model: nn.Module,
@@ -19,6 +18,7 @@ def convert_to_apex_amp(model: nn.Module,
     :return: (model, optimizer)
     :rtype: Tuple
     """
+    import apex.amp as apex_amp
     model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
     optimizer = ApexAMPOptimizer(optimizer)
     return model, optimizer
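Taken together, the two hunks above move the apex import from module scope into the body of convert_to_apex_amp, so apex only needs to be installed when apex AMP is actually requested. A minimal sketch of the resulting function, assuming the (model, optimizer, amp_config) signature suggested by the context lines:

import torch.nn as nn
from torch.optim import Optimizer

from .apex_amp import ApexAMPOptimizer


def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config: dict):
    # Import apex lazily so the dependency is only required for this AMP mode.
    import apex.amp as apex_amp
    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
    optimizer = ApexAMPOptimizer(optimizer)
    return model, optimizer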
@@ -30,11 +30,7 @@ def convert_to_zero(model: nn.Module,
     :rtype: Tuple
     """
     assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
-    if level == 2:
-        if is_no_pp_or_last_stage():
-            model = NaiveAMPModel(model, output_to_fp32=True)
-        else:
-            model = NaiveAMPModel(model, output_to_fp32=False)
+    model = NaiveAMPModel(model, output_to_fp32=False)
 
     if level == 2:
         optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
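This hunk is the dtype fix named in the commit title: previously the model was wrapped in NaiveAMPModel, and therefore cast to fp16, only when level == 2, leaving the level-3 path with an fp32 model and a dtype mismatch further downstream. After the change the wrap is applied for both levels. A minimal sketch of the resulting convert_to_zero flow; the level-3 branch is assumed from the surrounding module and is not part of this hunk:

def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
    assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
    # Cast the model to fp16 for both ZeRO levels so downstream code sees one dtype.
    model = NaiveAMPModel(model, output_to_fp32=False)

    if level == 2:
        optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
    else:
        # Assumed from the surrounding file: level 3 takes the wrapped module as well.
        optimizer = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer, module=model, **zero_config)
    return model, optimizer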
@@ -695,13 +695,23 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
             },
             "aio": aio_config
         }
-        remote_device = offload_param_config['device']
+        if offload_param_config is not None:
+            remote_device = offload_param_config['device']
+        else:
+            remote_device = None
+
+        if offload_optimizer_config is not None:
+            pin_memory = offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
+        else:
+            pin_memory = False
 
         group = None
         if gpc.is_initialized(ParallelMode.DATA):
             group = gpc.get_group(ParallelMode.DATA)
         Init(module=module, data_parallel_group=group, dtype=self.dtype,
              remote_device=remote_device, config_dict_or_path=ds_config,
-             pin_memory=offload_optimizer_config[OFFLOAD_OPTIMIZER_PIN_MEMORY])
+             pin_memory=pin_memory)
 
         for m in module.modules():
             _init_external_params(m)
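This hunk makes the offload configuration optional: the old code indexed offload_param_config and offload_optimizer_config unconditionally, so constructing the level-3 optimizer without CPU/NVMe offloading configured would fail on a None config. The None-safe lookup pattern in isolation, as a minimal sketch (the constant's value and the config shapes here are assumptions for illustration):

OFFLOAD_OPTIMIZER_PIN_MEMORY = 'pin_memory'  # assumed value of the config key


def resolve_offload_settings(offload_param_config, offload_optimizer_config):
    # Fall back to defaults instead of indexing into a None config and crashing.
    remote_device = offload_param_config['device'] if offload_param_config is not None else None
    pin_memory = (offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
                  if offload_optimizer_config is not None else False)
    return remote_device, pin_memory


print(resolve_offload_settings(None, None))                                # (None, False)
print(resolve_offload_settings({'device': 'cpu'}, {'pin_memory': True}))   # ('cpu', True)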
@@ -89,10 +89,10 @@ def run_dist(rank, world_size):
     model.train()
     for idx, (data, label) in enumerate(train_dataloader):
         engine.zero_grad()
-        data = data.cuda().half()
+        data = data.cuda()
         label = label.cuda()
 
-        output = engine(data).float()
+        output = engine(data)
         loss = engine.criterion(output, label)
 
         engine.backward(loss)
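Because the model is now wrapped in NaiveAMPModel inside convert_to_zero, the test no longer casts inputs to half and outputs back to float by hand; the wrapper handles the dtype conversion at its boundary. Conceptually it behaves roughly like the stand-in below (an illustrative sketch, not ColossalAI's actual NaiveAMPModel):

import torch
import torch.nn as nn


class NaiveAMPModelSketch(nn.Module):
    """Illustrative stand-in: keep fp16 parameters, cast inputs/outputs at the boundary."""

    def __init__(self, model: nn.Module, output_to_fp32: bool = True):
        super().__init__()
        self.model = model.half()          # parameters live in fp16
        self.output_to_fp32 = output_to_fp32

    def forward(self, x: torch.Tensor):
        if torch.is_floating_point(x):
            x = x.half()                   # accept fp32 inputs, cast them here
        out = self.model(x)
        if self.output_to_fp32 and torch.is_floating_point(out):
            out = out.float()              # hand fp32 back to the caller if requested
        return out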
@@ -104,7 +104,6 @@ def run_dist(rank, world_size):
 
 
 @pytest.mark.dist
-@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
 def test_zero_level_3():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size)
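With the dtype bug fixed, the skip marker is removed so the ZeRO level-3 test runs again (the next hunk makes the same change to the ViT test). For context, a minimal sketch of the launch pattern these entry points follow, assuming the usual torch.multiprocessing spawn approach; the body of run_dist is elided:

from functools import partial

import pytest
import torch.multiprocessing as mp


def run_dist(rank, world_size):
    # Each spawned worker would initialize the distributed backend here
    # and execute the training loop under test.
    ...


@pytest.mark.dist
def test_zero_level_3():
    world_size = 4
    # partial binds world_size so mp.spawn only has to supply the rank.
    run_func = partial(run_dist, world_size=world_size)
    mp.spawn(run_func, nprocs=world_size)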
@@ -108,7 +108,6 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
 
 
 @pytest.mark.dist
-@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
 def test_3d_vit_zero_level_3():
     world_size = 8
     run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)