diff --git a/colossalai/amp/apex_amp/__init__.py b/colossalai/amp/apex_amp/__init__.py
index 23cffae0f..16e7c36f3 100644
--- a/colossalai/amp/apex_amp/__init__.py
+++ b/colossalai/amp/apex_amp/__init__.py
@@ -1,7 +1,6 @@
 from .apex_amp import ApexAMPOptimizer
 import torch.nn as nn
 from torch.optim import Optimizer
-import apex.amp as apex_amp
 
 
 def convert_to_apex_amp(model: nn.Module,
@@ -19,6 +18,7 @@ def convert_to_apex_amp(model: nn.Module,
     :return: (model, optimizer)
     :rtype: Tuple
     """
+    import apex.amp as apex_amp
     model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
     optimizer = ApexAMPOptimizer(optimizer)
     return model, optimizer
diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py
index e464be5ce..1eb03d1a7 100644
--- a/colossalai/zero/__init__.py
+++ b/colossalai/zero/__init__.py
@@ -30,11 +30,7 @@ def convert_to_zero(model: nn.Module,
     :rtype: Tuple
     """
     assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
-    if level == 2:
-        if is_no_pp_or_last_stage():
-            model = NaiveAMPModel(model, output_to_fp32=True)
-        else:
-            model = NaiveAMPModel(model, output_to_fp32=False)
+    model = NaiveAMPModel(model, output_to_fp32=False)
 
     if level == 2:
         optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
diff --git a/colossalai/zero/zero_redundancy_optimizer_level_3.py b/colossalai/zero/zero_redundancy_optimizer_level_3.py
index 32c5afba8..34051e638 100644
--- a/colossalai/zero/zero_redundancy_optimizer_level_3.py
+++ b/colossalai/zero/zero_redundancy_optimizer_level_3.py
@@ -695,13 +695,23 @@ class ZeroRedundancyOptimizer_Level_3(Optimizer):
             },
             "aio": aio_config
         }
-        remote_device = offload_param_config['device']
+
+        if offload_param_config is not None:
+            remote_device = offload_param_config['device']
+        else:
+            remote_device = None
+
+        if offload_optimizer_config is not None:
+            pin_memory = offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
+        else:
+            pin_memory = False
+
         group = None
         if gpc.is_initialized(ParallelMode.DATA):
             group = gpc.get_group(ParallelMode.DATA)
         Init(module=module, data_parallel_group=group, dtype=self.dtype,
              remote_device=remote_device, config_dict_or_path=ds_config,
-             pin_memory=offload_optimizer_config[OFFLOAD_OPTIMIZER_PIN_MEMORY])
+             pin_memory=pin_memory)
 
         for m in module.modules():
             _init_external_params(m)
diff --git a/tests/test_zero_data_parallel/test_zero_level_3.py b/tests/test_zero_data_parallel/test_zero_level_3.py
index e202b31e4..f1fe45b2b 100644
--- a/tests/test_zero_data_parallel/test_zero_level_3.py
+++ b/tests/test_zero_data_parallel/test_zero_level_3.py
@@ -89,10 +89,10 @@ def run_dist(rank, world_size):
     model.train()
     for idx, (data, label) in enumerate(train_dataloader):
         engine.zero_grad()
-        data = data.cuda().half()
+        data = data.cuda()
         label = label.cuda()
 
-        output = engine(data).float()
+        output = engine(data)
         loss = engine.criterion(output, label)
 
         engine.backward(loss)
@@ -104,7 +104,6 @@
 
 
 @pytest.mark.dist
-@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
 def test_zero_level_3():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size)
diff --git a/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py b/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py
index 38c9d8a70..275ff1997 100644
--- a/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py
+++ b/tests/test_zero_tensor_parallel/test_vit_2d_level_3.py
@@ -108,7 +108,6 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
 
 
 @pytest.mark.dist
-@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
 def test_3d_vit_zero_level_3():
     world_size = 8
     run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)
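
A minimal, self-contained sketch of the None-safe offload handling introduced above in zero_redundancy_optimizer_level_3.py. The resolve_offload_settings helper and the standalone OFFLOAD_OPTIMIZER_PIN_MEMORY constant are hypothetical, introduced here only to illustrate the defaults the patch falls back to when offload_param_config or offload_optimizer_config is omitted; they are not part of the ColossalAI API.

from typing import Optional, Tuple

# Key name mirroring the constant referenced in the diff (DeepSpeed-style offload config).
OFFLOAD_OPTIMIZER_PIN_MEMORY = "pin_memory"


def resolve_offload_settings(offload_param_config: Optional[dict],
                             offload_optimizer_config: Optional[dict]) -> Tuple[Optional[str], bool]:
    """Hypothetical helper reproducing the fallbacks added in the patch."""
    # No parameter-offload config: keep parameters where they are (no remote device).
    if offload_param_config is not None:
        remote_device = offload_param_config['device']
    else:
        remote_device = None

    # No optimizer-offload config: default to pageable (non-pinned) host memory.
    if offload_optimizer_config is not None:
        pin_memory = offload_optimizer_config.get(OFFLOAD_OPTIMIZER_PIN_MEMORY, False)
    else:
        pin_memory = False

    return remote_device, pin_memory


if __name__ == "__main__":
    # Both configs omitted: the old code raised TypeError on offload_param_config['device'].
    assert resolve_offload_settings(None, None) == (None, False)
    # Configs provided: behaviour is unchanged from before the patch.
    assert resolve_offload_settings({'device': 'cpu'},
                                    {OFFLOAD_OPTIMIZER_PIN_MEMORY: True}) == ('cpu', True)

With both offload configs left unset, initialization now degrades to no parameter offloading and pageable optimizer memory instead of failing on a subscript of None, which is what lets the previously skipped level-3 tests above run.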