From 5521af7877cebf1f3147dd9d60224e20a3733b8f Mon Sep 17 00:00:00 2001
From: HELSON
Date: Wed, 11 Jan 2023 14:55:41 +0800
Subject: [PATCH] [zero] fix state_dict and load_state_dict for ddp ignored parameters (#2443)

* [ddp] add is_ddp_ignored

[ddp] rename to is_ddp_ignored

* [zero] fix state_dict and load_state_dict

* fix bugs

* [zero] update unit test for ZeroDDP
---
 colossalai/nn/parallel/data_parallel.py | 22 +++++++++++++++----
 .../update/test_zeroddp_state_dict.py   | 12 ++++++++--
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 649bd920d..28a10c4b6 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -233,7 +233,7 @@ class ZeroDDP(ColoDDP):
             assert isinstance(p, ColoParameter)
 
             if is_ddp_ignored(p):
-                p.data = p.data.half()
+                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
                 continue
 
             fp32_data = p.data.float()
@@ -451,8 +451,14 @@ class ZeroDDP(ColoDDP):
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
 
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
-        # TODO: (HELSON) deal with ddp ignored parameters
-        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+        ddp_param_list = []
+        for name, param in self.named_parameters():
+            if is_ddp_ignored(param):
+                # deal with ddp ignored parameters
+                destination[prefix + name] = param if keep_vars else param.detach()
+            else:
+                ddp_param_list.append((name, param))
+        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
                 record_parameter = param_to_save_data[fp32_p]
@@ -588,8 +594,16 @@ class ZeroDDP(ColoDDP):
         def load_fp32_parameter(chunk_slice, data):
             chunk_slice.copy_(data.flatten())
 
+        ddp_param_list = []
+        for name, param in self.named_parameters():
+            if is_ddp_ignored(param):
+                # deal with ddp ignored parameters
+                load(name, param, param.copy_)
+            else:
+                ddp_param_list.append((name, param))
+
         fp32_to_name = dict()
-        for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
             if p is not None:
                 fp32_to_name[fp32_p] = name
 
diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py
index b902bb0f0..266b8eab1 100644
--- a/tests/test_gemini/update/test_zeroddp_state_dict.py
+++ b/tests/test_gemini/update/test_zeroddp_state_dict.py
@@ -4,6 +4,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch.testing import assert_close
 
 import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
@@ -17,6 +18,13 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import debug_print, set_seed
 
 
+def ignore_the_first_parameter(model: torch.nn.Module):
+    for name, param in model.named_parameters():
+        print(f"parameter `{name}` is set ignored")
+        ZeroDDP.set_params_to_ignore([param])
+        return
+
+
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
 @parameterize('keep_gathered', [True, False])
 @parameterize('model_name', ['gpt2', 'bert'])
@@ -47,7 +55,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
     for key, value in torch_dict.items():
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
-        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
 
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@@ -84,7 +92,7 @@
     for key, value in torch_dict.items():
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
         temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
-        assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
+        assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
 
 
 def run_dist(rank, world_size, port):