diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 24d59e177..a30416ab9 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -5,6 +5,7 @@ from typing import Dict, Iterable, List, Optional, Set

 import torch
 import torch.distributed as dist
+import torch.nn as nn

 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
@@ -218,11 +219,15 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(gemini_manager)
-        self.fp32_params: List[ColoTensor] = []
+        self.fp32_params: List[ColoTensor] = list()
+        self.fp16_params: List[ColoParameter] = list()
         self.overflow_counter = 0
-        self.grads_device: Dict[torch.Tensor, torch.device] = {}
+        self.grads_device: Dict[torch.Tensor, torch.device] = dict()
+        self.param2name: Dict[nn.Parameter, str] = dict()
+        self.name2param: Dict[str, nn.Parameter] = dict()
-        cpu_offload = self.gemini_manager.policy_name != 'cuda'
+        self._cast_buffers()
+        self._logger = get_dist_logger()

         if self.gemini_manager._premade_memstats_:
             # build chunk in param runtime visited order.
@@ -234,50 +239,17 @@ class ZeroDDP(ColoDDP):
             for p in module.parameters():
                 param_order.append(p)

-        ddp_pg = ColoProcessGroup()
-        for p in param_order.generate():
-            assert isinstance(p, ColoParameter)
+        self._init_chunks(param_order=param_order,
+                          strict_ddp_mode=strict_ddp_mode,
+                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
+                          pin_memory=pin_memory)

-            if strict_ddp_mode:
-                if not p.is_replicate():
-                    p.set_dist_spec(ReplicaSpec())
-                p.set_process_group(pg=ddp_pg)
-
-            if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
-                continue
-
-            fp32_data = p.data.float()
-            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
-            p.data = p.data.half()
-            dp_world_size = p.process_group.dp_world_size()
-            self.chunk_manager.register_tensor(tensor=p,
-                                               group_type='fp16_param',
-                                               config_key=dp_world_size,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.chunk_manager.register_tensor(tensor=fp32_p,
-                                               group_type='fp32_param',
-                                               config_key=dp_world_size,
-                                               cpu_offload=cpu_offload,
-                                               pin_memory=pin_memory)
-            self.fp32_params.append(fp32_p)
-            self.grads_device[p] = self.gemini_manager.default_device
-
-        self.chunk_manager.close_all_groups()
-        self._cast_buffers()
-
-        params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
-        for p, fp32_p in zip(params_list, self.fp32_params):
-            chunk_16 = self.chunk_manager.get_chunk(p)
-            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
-            chunk_32.init_pair(chunk_16)
-
-            # keep gathered chunks are in CUDA
-            if chunk_16.keep_gathered:
-                self.grads_device[p] = get_current_device()
-
-        self._logger = get_dist_logger()
+        for name, param in module.named_parameters():
+            self.param2name[param] = name
+        for m_name, m_var in module.named_modules():
+            for p_name, p_var in m_var.named_parameters(recurse=False):
+                param_name = m_name + '.' + p_name if m_name else p_name
+                self.name2param[param_name] = p_var

     def _post_forward(self):
         """This function is only triggered for inference.
@@ -318,10 +290,23 @@ class ZeroDDP(ColoDDP):
                 continue
             p.grad = None

+    def _pre_backward(self):
+        # set a visit label for all parameters
+        # the label is used to check whether the parameter is correctly reduced
+        for param in self.param2name:
+            if not is_ddp_ignored(param):
+                setattr(param, "_gemini_reduced", False)
+
     def _post_backward(self):
         if self.chunk_manager.accessed_mem != 0:
+            error_params = ["Reduction failed at the following parameters:"]
+            for param in self.param2name:
+                if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
+                    error_params.append(self.param2name[param])
+            error_str = "\n\t".join(error_params)
             raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
-                               "The most possible reason is that the model is not compatible with ZeroDDP.")
+                               "The most possible reason is that the model is not compatible with ZeroDDP.\n",
+                               f"{error_str}")
         self._setup_grads_ptr()
         self._logger.debug(
             f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
@@ -329,6 +314,7 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()

     def backward(self, loss: torch.Tensor):
+        self._pre_backward()
         with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
@@ -343,7 +329,9 @@ class ZeroDDP(ColoDDP):
        free_storage(empty_grad)
        with torch._C.DisableTorchFunction():
            chunk = self.chunk_manager.get_chunk(p)
-           assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
+           if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
+               raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+                                  "Some unsupported torch function was applied to this parameter.")
            self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
            chunk.copy_tensor_to_chunk_slice(p, grad)
            reduced = self.chunk_manager.reduce_chunk(chunk)
@@ -367,30 +355,7 @@ class ZeroDDP(ColoDDP):
            for tensor in chunk.get_tensors():
                self.grads_device[tensor] = device

-    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True, strict: bool = True):
-        """
-        Args:
-            strict (bool): whether to reture the whole model state as the pytorch `Module.state_dict()`
-
-        Returns:
-            dict:
-                a dictionary containing a whole state of the module
-
-        Example:
-
-            >>> module.state_dict().keys()
-            ['bias', 'weight']
-        """
-        if strict:
-            assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
-            torch_model = get_static_torch_model(zero_ddp_model=self, only_rank_0=only_rank_0)
-            return torch_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
-        return self._non_strict_state_dict(destination=destination,
-                                           prefix=prefix,
-                                           keep_vars=keep_vars,
-                                           only_rank_0=only_rank_0)
-
-    def _non_strict_state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
         """Returns a dictionary containing a whole state of the module.

         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -461,19 +426,24 @@ class ZeroDDP(ColoDDP):
         """
         assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+        # get copies of fp32 parameters in CPU
         param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
-        ddp_param_list = []
-        for name, param in self.named_parameters():
-            if is_ddp_ignored(param):
-                # deal with ddp ignored parameters
-                destination[prefix + name] = param if keep_vars else param.detach()
-            else:
-                ddp_param_list.append((name, param))
-        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
-            if p is not None:
-                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
-                record_parameter = param_to_save_data[fp32_p]
-                destination[prefix + name] = record_parameter
+        # get the mapping between copies and fp16 parameters
+        p_mapping = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            name = self.param2name[p]
+            assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+            record_parameter = param_to_save_data[fp32_p]
+            p_mapping[p] = record_parameter
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    destination[prefix + name] = param if keep_vars else param.detach()
+                else:
+                    destination[prefix + name] = p_mapping[param]
+        del p_mapping
+        del param_to_save_data

         # save all buffers
         for name, buf in self.named_buffers():
@@ -605,17 +575,15 @@ class ZeroDDP(ColoDDP):
         def load_fp32_parameter(chunk_slice, data):
             chunk_slice.copy_(data.flatten())

-        ddp_param_list = []
         for name, param in self.named_parameters():
             if is_ddp_ignored(param):
                 # deal with ddp ignored parameters
                 load(name, param, param.copy_)
-            else:
-                ddp_param_list.append((name, param))

         fp32_to_name = dict()
-        for (name, p), fp32_p in zip(ddp_param_list, self.fp32_params):
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
             if p is not None:
+                name = self.param2name[p]
                 fp32_to_name[fp32_p] = name

         chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
@@ -662,6 +630,60 @@ class ZeroDDP(ColoDDP):
                 if input_name not in local_state:
                     unexpected_keys.append(key)

+    def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
+        ddp_pg = ColoProcessGroup()
+        for p in param_order.generate():
+            assert isinstance(p, ColoParameter)
+
+            # gather sharded parameters in the strict ddp mode
+            if strict_ddp_mode:
+                if not p.is_replicate():
+                    p.set_dist_spec(ReplicaSpec())
+                p.set_process_group(pg=ddp_pg)
+
+            # ignore the parameters with no gradient
+            if not p.requires_grad:
+                self.set_params_to_ignore([p])
+
+            # move ignored parameters to CUDA
+            if is_ddp_ignored(p):
+                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                continue
+
+            # create a fp32 parameter
+            fp32_data = p.data.float()
+            fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
+            # create a fp16 parameter
+            p.data = p.data.half()
+
+            # register the fp16 parameter and fp32 parameter in the chunk manager
+            dp_world_size = p.process_group.dp_world_size()
+            self.chunk_manager.register_tensor(tensor=p,
+                                               group_type='fp16_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+            self.chunk_manager.register_tensor(tensor=fp32_p,
+                                               group_type='fp32_param',
+                                               config_key=dp_world_size,
+                                               cpu_offload=cpu_offload,
+                                               pin_memory=pin_memory)
+
+            self.fp16_params.append(p)
+            self.fp32_params.append(fp32_p)
+            self.grads_device[p] = self.gemini_manager.default_device
+
+        self.chunk_manager.close_all_groups()
+
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            chunk_16 = self.chunk_manager.get_chunk(p)
+            chunk_32 = self.chunk_manager.get_chunk(fp32_p)
+            chunk_32.init_pair(chunk_16)
+
+            # keep gathered chunks in CUDA
+            if chunk_16.keep_gathered:
+                self.grads_device[p] = get_current_device()
+
     def _cast_buffers(self):
         for buffer in self.module.buffers():
             buffer.data = buffer.cuda()
diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py
index 636f1ec74..2c6e15d91 100644
--- a/colossalai/nn/parallel/gemini_parallel.py
+++ b/colossalai/nn/parallel/gemini_parallel.py
@@ -49,6 +49,10 @@ class GeminiDDP(ZeroDDP):
                 all parameters will be compacted into one small chunk.
             memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
+        # some ugly hotfix for the compatibility with Lightning
+        if search_range_mb is None:
+            search_range_mb = 32
+
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py
index d323556d5..08fdb6026 100644
--- a/colossalai/nn/parallel/utils.py
+++ b/colossalai/nn/parallel/utils.py
@@ -80,13 +80,11 @@ def get_static_torch_model(zero_ddp_model,
     from colossalai.nn.parallel import ZeroDDP
     assert isinstance(zero_ddp_model, ZeroDDP)

-    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0, strict=False)
+    state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
     colo_model = zero_ddp_model.module
     torch_model = _get_shallow_copy_model(colo_model)
     if not only_rank_0 or dist.get_rank() == 0:
-        # record the mapping relationship between colo parameters and torch parameters
-        colo_to_torch = dict()
         for (name, colo_module), (_, torch_module) in \
                 zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
             # clean the parameter list of the new torch module
@@ -94,17 +92,10 @@ def get_static_torch_model(zero_ddp_model,
             for sufix_param_name, param in colo_module.named_parameters(recurse=False):
                 # get the full name of the parameter
                 full_param_name = name + ('.' if name else '') + sufix_param_name
-
-                if full_param_name not in state_dict:
-                    # this means the parameter is shared by multiple modules
-                    # we should use colo_to_torch to get the torch parameter created before
-                    assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
-                    torch_param = colo_to_torch[param]
-                else:
-                    # we meet the parameter the first time, just use the state dict to get the data
-                    state_param = state_dict[full_param_name]
-                    torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
-                    colo_to_torch[param] = torch_param
+                assert full_param_name in state_dict, \
+                    f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
+                state_param = state_dict[full_param_name]
+                torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
                 setattr(torch_module, sufix_param_name, torch_param)

    dist.barrier()
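Note (not part of the patch): the reduction check added in `_pre_backward`/`_post_backward` above follows a simple visit-label pattern. The standalone sketch below illustrates it with a toy module; `ToyModule`, `mark_reduced`, and `report_unreduced` are hypothetical names used only for this illustration and do not exist in ColossalAI.

# Standalone illustration of the visit-label pattern (hypothetical helper names).
import torch
import torch.nn as nn


class ToyModule(nn.Module):
    # a tiny stand-in model, only used for this illustration

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


def pre_backward(module: nn.Module) -> None:
    # label every parameter as "not reduced yet" before backward starts
    for param in module.parameters():
        setattr(param, "_gemini_reduced", False)


def mark_reduced(param: nn.Parameter) -> None:
    # the gradient-reduction path would flip the label when it handles the parameter
    setattr(param, "_gemini_reduced", True)


def report_unreduced(module: nn.Module) -> list:
    # after backward, any parameter still labelled False was never reduced
    return [name for name, p in module.named_parameters() if not getattr(p, "_gemini_reduced", True)]


model = ToyModule()
pre_backward(model)
model(torch.randn(2, 4)).sum().backward()
mark_reduced(model.fc.weight)     # pretend only the weight went through reduction
print(report_unreduced(model))    # ['fc.bias']

In the patch itself, `_pre_backward` clears the `_gemini_reduced` label on every non-ignored parameter, the label is presumably set to True by the chunk reduction path (outside this diff), and `_post_backward` lists any parameter still unmarked in the error it raises.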