From ee112fe1dad8738b498c963c5f64dde4e0a0399f Mon Sep 17 00:00:00 2001 From: HELSON Date: Fri, 8 Apr 2022 20:23:26 +0800 Subject: [PATCH] [zero] adapt zero hooks for unsharded module (#699) --- colossalai/engine/ophooks/zero_hook.py | 59 ++++++++++++------- colossalai/zero/init_ctx/init_context.py | 26 ++++---- .../zero/sharded_model/sharded_model_v2.py | 6 +- .../zero/sharded_optim/sharded_optim_v2.py | 8 +-- .../zero/sharded_param/sharded_param.py | 7 ++- tests/test_moe/test_moe_zero_init.py | 1 - tests/test_moe/test_moe_zero_model.py | 2 +- tests/test_moe/test_moe_zero_optim.py | 4 +- tests/test_zero_data_parallel/common.py | 7 +-- .../test_shard_model_v2.py | 2 +- .../test_shard_param.py | 6 +- .../test_state_dict.py | 2 +- 12 files changed, 71 insertions(+), 59 deletions(-) diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py index b0ab82a94..a43cf1878 100644 --- a/colossalai/engine/ophooks/zero_hook.py +++ b/colossalai/engine/ophooks/zero_hook.py @@ -36,6 +36,7 @@ class ZeroHook(BaseOpHook): self._stateful_tensor_mgr = stateful_tensor_mgr def pre_fwd_exec(self, module: torch.nn.Module, *args): + for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE) @@ -45,12 +46,15 @@ class ZeroHook(BaseOpHook): for param in module.parameters(recurse=False): colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device) - tensor_list = [] - for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') - tensor_list.append(param.colo_attr.sharded_data_tensor) - self.shard_strategy.gather(tensor_list, self.process_group) + # gather sharded parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.gather(tensor_list, self.process_group) + # record memory statistics if self._memstarts_collector: self._memstarts_collector.sample_memstats() @@ -59,18 +63,25 @@ class ZeroHook(BaseOpHook): assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA" def post_fwd_exec(self, module: torch.nn.Module, *args): + + # change tensor state to HOLD_AFTER_FWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD) - tensor_list = [] - for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') - tensor_list.append(param.colo_attr.sharded_data_tensor) - self.shard_strategy.shard(tensor_list, self.process_group) + # shard gathered parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.shard(tensor_list, self.process_group) + + # remove torch payload for param in module.parameters(recurse=False): param.colo_attr.remove_torch_payload() def pre_bwd_exec(self, module: torch.nn.Module, input, output): + for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE) @@ -80,12 +91,15 @@ class ZeroHook(BaseOpHook): for param in module.parameters(recurse=False): colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device) - tensor_list = [] - for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') - 
tensor_list.append(param.colo_attr.sharded_data_tensor) - self.shard_strategy.gather(tensor_list, self.process_group) + # gather sharded parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.gather(tensor_list, self.process_group) + # record memory statistics if self._memstarts_collector: self._memstarts_collector.sample_memstats() @@ -94,15 +108,20 @@ class ZeroHook(BaseOpHook): assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA" def post_bwd_exec(self, module: torch.nn.Module, input): + + # change tensor state to HOLD_AFTER_BWD for param in module.parameters(recurse=False): param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD) - tensor_list = [] - for param in module.parameters(recurse=False): - assert hasattr(param, 'colo_attr') - tensor_list.append(param.colo_attr.sharded_data_tensor) - self.shard_strategy.shard(tensor_list, self.process_group) + # shard gathered parameters + if module.param_is_sharded: + tensor_list = [] + for param in module.parameters(recurse=False): + assert hasattr(param, 'colo_attr') + tensor_list.append(param.colo_attr.sharded_data_tensor) + self.shard_strategy.shard(tensor_list, self.process_group) + # remove torch payload for param in module.parameters(recurse=False): param.colo_attr.remove_torch_payload() diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index f5efe3d11..1a2fce2da 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -135,8 +135,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): super().__init__() self.shard_strategy = shard_strategy - self.sharded_param_list = [] - self.unshard_param_list = [] + self.param_list = [] self.model_numel_tensor = model_numel_tensor self.seed = seed self.dp_process_group = gpc.get_group(ParallelMode.DATA) @@ -210,19 +209,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): def _post_context_exec(self): """The callback function when exiting context. 
""" - for param in self.sharded_param_list: + # broadcast replicated no-shard parameters + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in self.param_list: assert hasattr(param, 'colo_attr') + if not param.colo_attr.param_is_sharded and param.is_replicated: + dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) param.colo_attr.remove_torch_payload() - del self.sharded_param_list - - src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] - for param in self.unshard_param_list: - assert hasattr(param, 'colo_attr') - if param.is_replicated: - dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) - - del self.unshard_param_list + del self.param_list nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout torch.set_rng_state(self.cpu_rng_state) @@ -264,10 +259,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if self.shard_param: self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group) - param.data = param.colo_attr.sharded_data_tensor.payload - self.sharded_param_list.append(param) - else: - self.unshard_param_list.append(param) + param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload + + self.param_list.append(param) # We must cast buffers # If we use BN, buffers may be on CPU and Float diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 9f05eb363..199ec882e 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -121,7 +121,7 @@ class ShardedModelV2(nn.Module): self._ophook_list = [ ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group) ] - register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded) + register_ophooks_recursively(self.module, self._ophook_list) self.param_hook_mgr = BaseParamHookMgr(self.sharded_params) self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) @@ -366,14 +366,12 @@ class ShardedModelV2(nn.Module): def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group) - prev_params = {} for p in self.sharded_params: - prev_params[p] = p.data p.data = p.colo_attr.sharded_data_tensor.payload gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars) self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group) for p in self.sharded_params: - p.data = prev_params[p] + p.colo_attr.remove_torch_payload() return gathered_state_dict def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index 31f58b9e0..bd708dad3 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -268,10 +268,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): p.data = self.master_params[p].payload p.colo_attr.sharded_data_tensor.reset_payload( colo_model_tensor_clone(p.half(), torch.cuda.current_device())) - - if not p.colo_attr.param_is_sharded: - # FIXME(hhc): add hook for unsharded parameters - p.data = p.colo_attr.sharded_data_tensor.payload + 
p.colo_attr.remove_torch_payload() def sync_grad(self): pass @@ -351,10 +348,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # TODO() optimize this line CPU (fp32) -> GPU (fp16) p.colo_attr.sharded_data_tensor.reset_payload( colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device)) + p.colo_attr.remove_torch_payload() if not is_param_sharded and not self.keep_unshard: # We gather full fp16 param here self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) - p.data = p.colo_attr.sharded_data_tensor.payload + self.master_params[p].trans_state(TensorState.HOLD) p.colo_attr.saved_grad.set_null() diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index 92f5bb59c..dff933a83 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -5,6 +5,11 @@ from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage from .tensorful_state import StatefulTensor, TensorState from typing import List +# use this tensor as empty data point for parameters +# we do not want users use param.data when its torch payload is removed +# empty tensor is expected to raise error when get used +FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu') + class ShardedParamV2(object): @@ -29,7 +34,7 @@ class ShardedParamV2(object): return [self._sharded_data_tensor] def remove_torch_payload(self): - self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device) + self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype) @property def sharded_data_tensor(self): diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 3c308a421..45dc061c7 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -66,7 +66,6 @@ def run_moe_zero_init(init_device_type, shard_strategy_class): # the parameters in moe experts and its gate should not be sharded if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): assert not param.colo_attr.sharded_data_tensor.is_sharded - assert param.colo_attr.sharded_data_tensor.data_ptr() == param.data.data_ptr() else: assert param.colo_attr.sharded_data_tensor.is_sharded diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 34c43bc06..336117f7a 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class): # check whether parameters are identical in ddp for name, p in zero_model.named_parameters(): if not p.colo_attr.param_is_sharded and p.is_replicated: - assert_equal_in_group(p.data) + assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload) model = MoeModel().half() col_model_deepcopy(zero_model, model) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 5a5063d6c..aa1ac57bc 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -74,7 +74,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g # check whether parameters are identical in ddp for name, p in zero_model.named_parameters(): if not p.colo_attr.param_is_sharded and p.is_replicated: - assert_equal_in_group(p.data.to(get_current_device())) + assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device())) model = 
MoeModel().half() col_model_deepcopy(zero_model, model) @@ -99,7 +99,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()): if 'gate' in n: p.data = p.float() - p.data.copy_(zp.data) + p.data.copy_(zp.colo_attr.sharded_data_tensor.payload) for i, (data, label) in enumerate(train_dataloader): if i > 5: diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py index 0143c0e3c..a35ed0060 100644 --- a/tests/test_zero_data_parallel/common.py +++ b/tests/test_zero_data_parallel/common.py @@ -126,16 +126,15 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard= rank = dist.get_rank() for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): if zero_p.colo_attr.param_is_sharded: - if reuse_fp16_shard: - zero_p = zero_p.data.to(p.device).float() - else: - zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float() + zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float() chunks = torch.flatten(p).chunk(dist.get_world_size()) if rank >= len(chunks): continue p = chunks[rank].float() if zero_p.size(0) > p.size(0): zero_p = zero_p[:p.size(0)] + else: + zero_p = zero_p.colo_attr.sharded_data_tensor.payload assert p.dtype == zero_p.dtype assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' diff --git a/tests/test_zero_data_parallel/test_shard_model_v2.py b/tests/test_zero_data_parallel/test_shard_model_v2.py index 449ceedc0..bf84fd29a 100644 --- a/tests/test_zero_data_parallel/test_shard_model_v2.py +++ b/tests/test_zero_data_parallel/test_shard_model_v2.py @@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd @parameterize("enable_autocast", [True]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) def run_model_test(enable_autocast, shard_strategy_class): test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module'] shard_strategy = shard_strategy_class() diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py index 2d2ae1075..79780fa51 100644 --- a/tests/test_zero_data_parallel/test_shard_param.py +++ b/tests/test_zero_data_parallel/test_shard_param.py @@ -58,15 +58,15 @@ def _run_shard_param_v2(rank, world_size, port): assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" sparam.remove_torch_payload() - assert (param.data.numel() == 1) + assert (param.data.numel() == 0) cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() # 4 is size of dummy tensor of param.data - assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 + assert cpu_mem_use == 2 * 3 * 4 * 2 sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) sparam.remove_torch_payload() cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 + assert cpu_mem_use == 2 * 3 * 4 * 2 assert cuda_mem_use == 0 # append a grad to torch param diff --git a/tests/test_zero_data_parallel/test_state_dict.py b/tests/test_zero_data_parallel/test_state_dict.py index 41cae05a5..818d05ffd 100644 --- a/tests/test_zero_data_parallel/test_state_dict.py +++ b/tests/test_zero_data_parallel/test_state_dict.py @@ -56,4 +56,4 @@ def test_zero_state_dict(world_size): if __name__ == '__main__': - test_zero_state_dict(2, TensorShardStrategy) + test_zero_state_dict(2)
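
Note for readers of this patch: the sketch below is a minimal, self-contained illustration (not ColossalAI code) of the control flow the reworked hooks implement — tensor state transitions and torch-payload removal now run for every hooked module, while gather/shard is performed only when `module.param_is_sharded` is true, which is what lets unsharded modules (e.g. MoE experts) share the same hook path. The classes `ShardedTensor`, `Param`, `Module`, and `ShardStrategy` are simplified stand-ins introduced only for this example.

# sketch.py -- simplified model of the conditional gather/shard flow in zero_hook.py
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class TensorState(Enum):
    HOLD = auto()
    COMPUTE = auto()
    HOLD_AFTER_FWD = auto()


@dataclass
class ShardedTensor:
    payload: list                        # stand-in for the real tensor payload
    state: TensorState = TensorState.HOLD

    def trans_state(self, state: TensorState) -> None:
        self.state = state


@dataclass
class Param:
    sharded_data_tensor: ShardedTensor
    has_torch_payload: bool = True

    def remove_torch_payload(self) -> None:
        # mirrors the idea of ShardedParamV2.remove_torch_payload: drop param.data
        # so stale data cannot be read between hook invocations
        self.has_torch_payload = False


@dataclass
class Module:
    params: List[Param]
    param_is_sharded: bool               # the flag the hooks now branch on

    def parameters(self) -> List[Param]:
        return list(self.params)


class ShardStrategy:
    def gather(self, tensors: List[ShardedTensor]) -> None:
        print(f"gather {len(tensors)} shard(s)")

    def shard(self, tensors: List[ShardedTensor]) -> None:
        print(f"shard {len(tensors)} tensor(s)")


def pre_fwd(strategy: ShardStrategy, module: Module) -> None:
    # state transition happens for every hooked module
    for p in module.parameters():
        p.sharded_data_tensor.trans_state(TensorState.COMPUTE)
    # only sharded modules need an all-gather before compute
    if module.param_is_sharded:
        strategy.gather([p.sharded_data_tensor for p in module.parameters()])


def post_fwd(strategy: ShardStrategy, module: Module) -> None:
    for p in module.parameters():
        p.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
    # only sharded modules are re-sharded after compute
    if module.param_is_sharded:
        strategy.shard([p.sharded_data_tensor for p in module.parameters()])
    # payload removal runs for sharded and unsharded modules alike
    for p in module.parameters():
        p.remove_torch_payload()


if __name__ == "__main__":
    strategy = ShardStrategy()
    sharded = Module([Param(ShardedTensor([1.0]))], param_is_sharded=True)
    unsharded = Module([Param(ShardedTensor([2.0]))], param_is_sharded=False)
    for m in (sharded, unsharded):
        pre_fwd(strategy, m)
        post_fwd(strategy, m)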