From df63564184d5b13ab5f6edbca0c5e9abe211ef95 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 12 Oct 2023 10:39:08 +0800 Subject: [PATCH] [gemini] support amp o3 for gemini (#4872) * [gemini] support no reuse fp16 chunk * [gemini] support no master weight for optim * [gemini] support no master weight for gemini ddp * [test] update gemini tests * [test] update gemini tests * [plugin] update gemini plugin * [test] fix gemini checkpointio test * [test] fix gemini checkpoint io --- colossalai/booster/plugin/gemini_plugin.py | 5 +- colossalai/nn/optimizer/cpu_adam.py | 6 +- colossalai/nn/optimizer/hybrid_adam.py | 6 +- colossalai/testing/comparison.py | 6 +- colossalai/zero/gemini/chunk/chunk.py | 52 ++++++- colossalai/zero/gemini/chunk/manager.py | 10 ++ colossalai/zero/gemini/gemini_ddp.py | 140 ++++++++++-------- colossalai/zero/gemini/gemini_optimizer.py | 52 ++++--- colossalai/zero/gemini/utils.py | 6 +- .../test_gemini_checkpoint_io.py | 10 +- .../test_gemini_torch_compability.py | 6 +- tests/test_zero/test_gemini/test_fwd_bwd.py | 14 +- tests/test_zero/test_gemini/test_grad_clip.py | 6 +- tests/test_zero/test_gemini/test_optim.py | 2 +- .../test_gemini/test_zeroddp_state_dict.py | 15 +- 15 files changed, 222 insertions(+), 114 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ca722a076..6c1658575 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -97,7 +97,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) @@ -257,6 +257,7 @@ class GeminiPlugin(DPPluginBase): warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. + master_weights (bool, optional): master weights. Defaults to True. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. @@ -296,6 +297,7 @@ class GeminiPlugin(DPPluginBase): warmup_non_model_data_ratio: float = 0.8, # only for auto placement steady_cuda_cap_ratio: float = 0.9, # only for auto placement precision: str = "fp16", + master_weights: bool = True, pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, @@ -334,6 +336,7 @@ class GeminiPlugin(DPPluginBase): min_chunk_size_m=min_chunk_size_m, memstats=memstats, mixed_precision=PRECISION_STR_TO_DTYPE[precision], + master_weights=master_weights, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 3dc729c32..1bdb81e2d 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -132,9 +132,6 @@ class CPUAdam(NVMeOptimizer): target_device = p.device if len(state) == 0: state["step"] = 0 - - # FIXME(ver217): CPU adam kernel only supports fp32 states now - assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances @@ -149,7 +146,8 @@ class CPUAdam(NVMeOptimizer): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - if p.grad.dtype is torch.bfloat16: + # FIXME(ver217): CPU adam kernel only supports fp32 states now + if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 32fc6136c..7dc4590dc 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -108,9 +108,6 @@ class HybridAdam(CPUAdam): target_device = p.device if len(state) == 0: state["step"] = 0 - - # FIXME(ver217): CPU adam kernel only supports fp32 states now - assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums state["exp_avg"] = torch.zeros_like(p, device=target_device) # gradient variances @@ -125,7 +122,8 @@ class HybridAdam(CPUAdam): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - if p.grad.dtype is torch.bfloat16: + # FIXME(ver217): CPU adam kernel only supports fp32 states now + if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 816bc0d7b..4f2a4878e 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -40,7 +40,7 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}" -def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): +def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False): assert len(list(d1.keys())) == len( list(d2.keys()) ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}" @@ -58,6 +58,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool if not ignore_device: v1_i = v1_i.to("cpu") v2_i = v2_i.to("cpu") + if ignore_dtype: + v1_i = v1_i.to(v2_i.dtype) assert_close_loose(v1_i, v2_i) elif isinstance(v1_i, dict): assert isinstance(v2_i, dict) @@ -69,6 +71,8 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool if not ignore_device: v1 = v1.to("cpu") v2 = v2.to("cpu") + if ignore_dtype: + v1 = v1.to(v2.dtype) assert_close_loose(v1, v2) else: assert v1 == v2, f"{v1} not equals to {v2}" diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index bbef9013c..c8be773b2 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -160,6 +160,8 @@ class Chunk: self.l2_norm_flag = False self.l2_norm = None + self.grad_chunk = None + @property def memory_usage(self) -> Dict[str, int]: cuda_memory = 0 @@ -414,7 +416,9 @@ class Chunk: return self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) - def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + def copy_tensor_to_chunk_slice( + self, tensor: torch.Tensor, data_slice: torch.Tensor, update_ptr: bool = True + ) -> None: """ Copy data slice to the memory space indexed by the input tensor in the chunk. @@ -427,7 +431,8 @@ class Chunk: tensor_info = self.tensors_info[tensor] self.cuda_global_chunk[tensor_info.offset : tensor_info.end].copy_(data_slice.data.flatten()) - tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) + if update_ptr: + tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) def get_valid_length(self) -> int: """Get the valid length of the chunk's payload.""" @@ -577,3 +582,46 @@ class Chunk: output.append("\t\t# of {}: {}\n".format(st, self.tensor_state_cnter[st])) return "".join(output) + + def init_grad_chunk(self) -> "Chunk": + """Init grad chunk. This should be called in grad handler. + + Returns: + Chunk: Grad chunk + """ + if self.grad_chunk is None: + # grad chunk is not initialized + grad_chunk = Chunk( + chunk_size=self.chunk_size, + process_group=self.torch_pg, + dtype=self.dtype, + keep_gathered=self.keep_gathered, + pin_memory=self.pin_memory, + ) + grad_chunk.num_tensors = self.num_tensors + grad_chunk.utilized_size = self.utilized_size + grad_chunk.tensor_state_cnter[TensorState.HOLD] = self.num_tensors + for tensor, state in self.tensors_info.items(): + grad_chunk.tensors_info[tensor] = TensorInfo(TensorState.HOLD, state.offset, state.end) + + grad_chunk.valid_end = self.valid_end + + if grad_chunk.chunk_temp.device.type == "cpu": + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device()) + else: + grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp + grad_chunk.chunk_temp = None + + if grad_chunk.pin_memory: + grad_chunk.cpu_shard = torch.empty( + grad_chunk.shard_size, dtype=grad_chunk.dtype, pin_memory=grad_chunk.pin_memory + ) + + self.grad_chunk = grad_chunk + else: + # grad chunk is initialized, just reallocate cuda global chunk + self.grad_chunk.cuda_shard = None + self.grad_chunk.is_gathered = True + alloc_storage(self.grad_chunk.cuda_global_chunk) + + return self.grad_chunk diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 957e41b02..713c11742 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -245,3 +245,13 @@ class ChunkManager: chunk.release_chunk() self.accessed_chunks.remove(chunk) self.accessed_mem -= chunk.chunk_mem + + def init_grad_chunk(self, chunk: Chunk) -> Chunk: + if chunk.grad_chunk is not None: + self.__sub_memory_usage(chunk.grad_chunk.memory_usage) + grad_chunk = chunk.init_grad_chunk() + self.__add_memory_usage(grad_chunk.memory_usage) + if grad_chunk not in self.accessed_chunks: + self.accessed_chunks.add(grad_chunk) + self.accessed_mem += grad_chunk.chunk_mem + return grad_chunk diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 0ba9e53cf..a4871f7e4 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -74,6 +74,7 @@ class GeminiDDP(ModelWrapper): mixed_precision: torch.dtype = torch.float16, process_group: Optional[ProcessGroup] = None, memstats: Optional[MemStats] = None, # genimi memory stats + master_weights: bool = True, verbose: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) @@ -115,6 +116,9 @@ class GeminiDDP(ModelWrapper): self.mixed_precision = mixed_precision self.dp_process_group = process_group or _get_default_group() + self.reuse_fp16_chunk = master_weights + self.master_weights = master_weights + self._logger = get_dist_logger() if self.gemini_manager._premade_memstats_: @@ -321,20 +325,37 @@ class GeminiDDP(ModelWrapper): f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " "Some unsupported torch function is operated upon this parameter." ) - self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) - chunk.copy_tensor_to_chunk_slice(p, grad) - reduced = self.chunk_manager.reduce_chunk(chunk) + grad_chunk = chunk + if not self.reuse_fp16_chunk: + grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + # hold -> compute -> hold after bwd + grad_chunk.tensor_trans_state(p, TensorState.COMPUTE) + grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD) + # fp16 param chunk: hold after bwd -> ready for reduce -> hold + chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) + chunk.tensor_trans_state(p, TensorState.HOLD) + + grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) + grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + reduced = self.chunk_manager.reduce_chunk(grad_chunk) if reduced: - if chunk.is_gathered: - chunk.cuda_global_chunk.div_(chunk.pg_size) + if not self.reuse_fp16_chunk: + if chunk.keep_gathered: + self.chunk_manager.fake_release_chunk(chunk) + else: + self.chunk_manager.release_chunk(chunk) + if grad_chunk.is_gathered: + grad_chunk.cuda_global_chunk.div_(chunk.pg_size) else: - chunk.cuda_shard.div_(chunk.pg_size) + grad_chunk.cuda_shard.div_(chunk.pg_size) # check overflow elements - self.overflow_counter += chunk.has_inf_or_nan - # record l2 norm for gradient clipping + self.overflow_counter += grad_chunk.has_inf_or_nan + # record l2 norm for gradient clipping. flag is bound to fp16 chunk if chunk.l2_norm_flag: - chunk.set_l2_norm() - self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + grad_chunk.set_l2_norm() + self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True) + if not self.master_weights: + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: @@ -344,9 +365,7 @@ class GeminiDDP(ModelWrapper): for tensor in chunk.get_tensors(): self.grads_device[tensor] = device - def state_dict( - self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True, dtype: torch.dtype = torch.float16 - ): + def state_dict(self, destination=None, prefix="", keep_vars=False, only_rank_0: bool = True): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. @@ -365,7 +384,7 @@ class GeminiDDP(ModelWrapper): destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0, dtype) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) for hook in self._state_dict_hooks.values(): hook_result = hook(self, destination, prefix, local_metadata) @@ -373,7 +392,7 @@ class GeminiDDP(ModelWrapper): destination = hook_result return destination - def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool, dtype: torch.dtype = torch.float16) -> Dict: + def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: """ get gathered chunk content. @@ -386,9 +405,8 @@ class GeminiDDP(ModelWrapper): """ # save parameters chunk_to_save_data = dict() - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - if torch.is_floating_point(temp_chunk): - temp_chunk = temp_chunk.to(dtype) + temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) + for tensor, tensor_info in chunk.tensors_info.items(): record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) @@ -401,9 +419,7 @@ class GeminiDDP(ModelWrapper): del temp_chunk return chunk_to_save_data - def _get_param_to_save_data( - self, param_list: List[torch.nn.Parameter], only_rank_0: bool, dtype: torch.dtype - ) -> Dict: + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: """ get param content from chunks. @@ -418,10 +434,10 @@ class GeminiDDP(ModelWrapper): param_to_save_data = dict() chunk_list = self.chunk_manager.get_chunks(param_list) for chunk in chunk_list: - param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) + param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0)) return param_to_save_data - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True, dtype=torch.float16): + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): r"""Saves module state to `destination` dictionary, containing a state of the module, but not its descendants. This is called on every submodule in :meth:`~torch.nn.Module.state_dict`. @@ -438,14 +454,18 @@ class GeminiDDP(ModelWrapper): # get copies of fp32 parameters in CPU # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16 - param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0, dtype) + params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params + param_to_save_data = self._get_param_to_save_data(params, only_rank_0) # get the mapping between copies and fp16 parameters p_mapping = dict() - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - name = self.param2name[p] - assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - p_mapping[p] = record_parameter + if self.reuse_fp16_chunk: + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter + else: + p_mapping = param_to_save_data for name, param in self.name2param.items(): if param is not None: if is_ddp_ignored(param): @@ -593,7 +613,7 @@ class GeminiDDP(ModelWrapper): elif strict: missing_keys.append(state_key) - def load_fp32_parameter(chunk_slice, data): + def load_parameter(chunk_slice, data): chunk_slice.copy_(data.flatten()) for name, param in self.named_parameters(): @@ -607,14 +627,15 @@ class GeminiDDP(ModelWrapper): name = self.param2name[p] fp32_to_name[fp32_p] = name - chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params + chunk_list = self.chunk_manager.get_chunks(params_to_load) for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) + temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) for tensor, tensor_info in chunk.tensors_info.items(): - parameter_name = fp32_to_name[tensor] + parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] - load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + load(parameter_name, tensor, partial(load_parameter, parameter_slice)) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) @@ -624,11 +645,11 @@ class GeminiDDP(ModelWrapper): chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin : chunk.shard_end]) del temp_chunk - - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.payload.copy_(chunk_32.payload) + if self.reuse_fp16_chunk: + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.payload.copy_(chunk_32.payload) for name, buf in persistent_buffers.items(): if buf is not None: @@ -668,12 +689,9 @@ class GeminiDDP(ModelWrapper): p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) continue - # create a fp32 parameter - fp32_p = p.data.float() # create a fp16 parameter p.data = p.data.to(self.mixed_precision) - - # register the fp16 parameter and fp32 parameter in the chunk manager + # register the fp16 parameter self.chunk_manager.register_tensor( tensor=p, group_type="fp16_param", @@ -682,22 +700,27 @@ class GeminiDDP(ModelWrapper): cpu_offload=cpu_offload, pin_memory=pin_memory, ) - self.chunk_manager.register_tensor( - tensor=fp32_p, - group_type="fp32_param", - config_key=dp_world_size, - process_group=self.dp_process_group, - cpu_offload=cpu_offload, - pin_memory=pin_memory, - ) - self.fp16_params.append(p) - self.fp32_params.append(fp32_p) + + if self.master_weights: + # create a fp32 parameter + fp32_p = p.data.float() + self.chunk_manager.register_tensor( + tensor=fp32_p, + group_type="fp32_param", + config_key=dp_world_size, + process_group=self.dp_process_group, + cpu_offload=cpu_offload, + pin_memory=pin_memory, + ) + self.fp32_params.append(fp32_p) self.chunk_manager.close_all_groups() self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device) + # move master weights to corresponding device and setup paired chunks + # if no master weights, fp32_params should be empty and this loop will be skipped for p, fp32_p in zip(self.fp16_params, self.fp32_params): chunk_16 = self.chunk_manager.get_chunk(p) chunk_32 = self.chunk_manager.get_chunk(fp32_p) @@ -734,7 +757,6 @@ class GeminiDDP(ModelWrapper): keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True, - dtype: torch.dtype = torch.float16, ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. @@ -769,11 +791,11 @@ class GeminiDDP(ModelWrapper): gathered_param = param if keep_vars else param.detach() else: # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16 - fp32_param = fp16_to_fp32[param] - if fp32_param not in gathered_param_buffer: - chunk = self.chunk_manager.get_chunk(fp32_param) - gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) - gathered_param = gathered_param_buffer.pop(fp32_param) + param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param + if param_to_save not in gathered_param_buffer: + chunk = self.chunk_manager.get_chunk(param_to_save) + gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) + gathered_param = gathered_param_buffer.pop(param_to_save) block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 1aece9954..3c42e96cb 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -105,7 +105,7 @@ class GeminiOptimizer(OptimizerWrapper): self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() - self.param_to_chunk32: Dict[Parameter, Chunk] = dict() + self.param_to_chunk16: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm @@ -130,7 +130,7 @@ class GeminiOptimizer(OptimizerWrapper): else: ddp_param_list.append(param) - for p, fp32_p in zip(ddp_param_list, module.fp32_params): + for p in ddp_param_list: chunk_16 = self.chunk_manager.get_chunk(p) if chunk_16 not in self.chunk16_set: chunk_16.l2_norm_flag = self.clipping_flag @@ -174,13 +174,15 @@ class GeminiOptimizer(OptimizerWrapper): def _set_grad_ptr(self): for group in self.param_groups: for fake_param in group["params"]: - chunk32 = self.param_to_chunk32[fake_param] + chunk16 = self.param_to_chunk16[fake_param] begin, end = self.param_to_range[fake_param] - chunk16 = chunk32.paired_chunk - fake_param.data = chunk16.payload[begin:end] + grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk + fake_param.data = grad_chunk16.payload[begin:end] fake_param.grad = fake_param.data - fake_param.data = chunk32.payload[begin:end] + + to_update_chunk = chunk16.paired_chunk if self.module.master_weights else chunk16 + fake_param.data = to_update_chunk.payload[begin:end] def _update_fp16_params(self): none_tensor = torch.empty([0]) @@ -194,23 +196,25 @@ class GeminiOptimizer(OptimizerWrapper): def _clear_global_norm(self) -> None: for c16 in self.chunk16_set: - c16.l2_norm = None + grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk + grad_chunk.l2_norm = None def _calc_global_norm(self) -> float: norm_sqr: float = 0.0 group_to_norm = dict() for c16 in self.chunk16_set: - assert c16.l2_norm is not None + grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk + assert grad_chunk.l2_norm is not None - if c16.is_gathered: - norm_sqr += c16.l2_norm + if grad_chunk.is_gathered: + norm_sqr += grad_chunk.l2_norm else: # this chunk is sharded, use communication to collect total norm - if c16.torch_pg not in group_to_norm: - group_to_norm[c16.torch_pg] = 0.0 - group_to_norm[c16.torch_pg] += c16.l2_norm + if grad_chunk.torch_pg not in group_to_norm: + group_to_norm[grad_chunk.torch_pg] = 0.0 + group_to_norm[grad_chunk.torch_pg] += grad_chunk.l2_norm - c16.l2_norm = None # clear l2 norm + grad_chunk.l2_norm = None # clear l2 norm comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device()) for group, part_norm in group_to_norm.items(): @@ -237,7 +241,8 @@ class GeminiOptimizer(OptimizerWrapper): return self.optim.zero_grad(set_to_none=True) def step(self, *args, **kwargs): - self._maybe_move_fp32_params() + if self.module.master_weights: + self._maybe_move_fp32_params() self._set_grad_ptr() if self.mix_precision_mixin.should_skip_step(): @@ -245,7 +250,8 @@ class GeminiOptimizer(OptimizerWrapper): self._logger.info(f"Found overflow. Skip step") self._clear_global_norm() # clear recorded norm self.zero_grad() # reset all gradients - self._update_fp16_params() + if self.module.reuse_fp16_chunk: + self._update_fp16_params() return # get combined scale. combined scale = loss scale * clipping norm @@ -255,7 +261,8 @@ class GeminiOptimizer(OptimizerWrapper): ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) self._register_states() self.zero_grad() - self._update_fp16_params() + if self.module.master_weights: + self._update_fp16_params() return ret def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): @@ -282,8 +289,8 @@ class GeminiOptimizer(OptimizerWrapper): for group in self.param_groups: for fake_param in group["params"]: - chunk32 = self.param_to_chunk32[fake_param] - chunk16 = chunk32.paired_chunk + chunk16 = self.param_to_chunk16[fake_param] + chunk32 = chunk16.paired_chunk if chunk32.device_type == "cuda": continue @@ -297,7 +304,8 @@ class GeminiOptimizer(OptimizerWrapper): for group in self.param_groups: for fake_param in group["params"]: - chunk32 = self.param_to_chunk32[fake_param] + chunk16 = self.param_to_chunk16[fake_param] + chunk32 = chunk16.paired_chunk if chunk32.device_type == "cuda": state = self.optim.state[fake_param] for k, v in state.items(): @@ -341,7 +349,7 @@ class GeminiOptimizer(OptimizerWrapper): continue grad_device = self.module.grads_device[param] fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) - self.param_to_chunk32[fake_param] = chunk16.paired_chunk + self.param_to_chunk16[fake_param] = chunk16 self.param_to_range[fake_param] = range_pair self.id_to_fake_params[param_id] = fake_param fake_params_list.append(fake_param) @@ -366,7 +374,7 @@ class GeminiOptimizer(OptimizerWrapper): if param_id not in self.id_to_fake_params: return -1, -1, -1 fake_param = self.id_to_fake_params[param_id] - chunk = self.param_to_chunk32[fake_param].paired_chunk + chunk = self.param_to_chunk16[fake_param] param = self.id_to_real_params[param_id] param_info = chunk.tensors_info[param] diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py index 264099d22..5305953fe 100644 --- a/colossalai/zero/gemini/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -11,7 +11,7 @@ from colossalai.utils import get_current_device from .chunk import Chunk -def get_temp_total_chunk_on_cuda(chunk: Chunk): +def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype): if chunk.is_gathered: return chunk.cuda_global_chunk @@ -20,7 +20,9 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk): else: shard_temp = chunk.cpu_shard.to(get_current_device()) - total_temp = torch.zeros(chunk.chunk_size, dtype=chunk.dtype, device=get_current_device()) + shard_temp = shard_temp.to(dtype) + + total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device()) gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0)) dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 634e81bb2..f87604038 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -58,9 +58,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b dist.barrier() new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal( - bert_model.state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False - ) + check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False) @clear_cache_before_run() @@ -100,7 +98,9 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha dist.barrier() booster.load_model(new_model, model_ckpt_path) - check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False) + check_state_dict_equal( + model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True + ) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) check_state_dict_equal( @@ -136,7 +136,7 @@ def exam_lazy_from_pretrained(): booster.save_model(model, save_path, shard=False) dist.barrier() state_dict = torch.load(save_path, map_location="cpu") - check_state_dict_equal(state_dict, orig_state_dict, False) + check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True) def run_dist(rank, world_size, port): diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py index d46e5380d..bb7a60035 100644 --- a/tests/test_checkpoint_io/test_gemini_torch_compability.py +++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py @@ -60,9 +60,10 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str): # Add prefix to get aligned with pytorch parameter names. check_state_dict_equal( - model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + model.state_dict(only_rank_0=False, prefix="module.module."), new_model.state_dict(), False, + ignore_dtype=True, ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) @@ -125,9 +126,10 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str): # Add prefix to get aligned with pytorch parameter names. check_state_dict_equal( - new_model.state_dict(only_rank_0=False, prefix="module.module.", dtype=torch.float32), + new_model.state_dict(only_rank_0=False, prefix="module.module."), model.state_dict(), False, + ignore_dtype=True, ) new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 94e700400..2fb2bcbc8 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -27,6 +27,8 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): chunk_manager = model.chunk_manager param_list = [p for p in model.parameters()] chunk_list = chunk_manager.get_chunks(param_list) + if not model.reuse_fp16_chunk: + chunk_list = [chunk.grad_chunk for chunk in chunk_list] for chunk in chunk_list: chunk_manager.access_chunk(chunk) @@ -36,13 +38,15 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gather", [False, True]) -@parameterize("model_name", ["gpt2", "bert", "albert"]) +@parameterize("model_name", ["gpt2", "bert"]) @parameterize("use_grad_checkpoint", [False, True]) +@parameterize("master_weights", [False, True]) def exam_gpt_fwd_bwd( placement_config, keep_gather, model_name: str, use_grad_checkpoint: bool = False, + master_weights: bool = True, ): init_device = get_current_device() get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -60,12 +64,14 @@ def exam_gpt_fwd_bwd( config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gather - model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config) + model = GeminiDDP( + model, config_dict, init_device, pin_memory=True, **placement_config, master_weights=master_weights + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1) rank = dist.get_rank() - amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1) + amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, master_weights=master_weights) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model = DDP(torch_model, device_ids=[rank]) @@ -106,4 +112,4 @@ def test_gpt(world_size): if __name__ == "__main__": - test_gpt(4) + test_gpt(1) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index d8bcc555a..a3af81646 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -78,7 +78,11 @@ def exam_grad_clipping(placement_config, model_name: str): init_device = None model = GeminiDDP( - model, chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, **placement_config + model, + chunk_config_dict=config_dict, + chunk_init_device=init_device, + pin_memory=True, + **placement_config, ) optimizer = HybridAdam(model.parameters(), lr=1e-3) diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index b7c083926..8e8e508ff 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -44,7 +44,7 @@ BF16_IGNORED_KEYS = [ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype): - zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) + zero_dict = model.state_dict(only_rank_0=False) torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 3130440bd..bf16a301c 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -27,7 +27,8 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["gpt2", "bert"]) -def exam_state_dict(placement_config, keep_gathered, model_name: str): +@parameterize("master_weights", [False, True]) +def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -42,7 +43,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) model.train() zero_dict = model.state_dict(only_rank_0=False) @@ -57,7 +58,8 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) @parameterize("model_name", ["gpt2", "bert"]) -def exam_load_state_dict(placement_config, keep_gathered, model_name: str): +@parameterize("master_weights", [False, True]) +def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -72,7 +74,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = keep_gathered - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True) + model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) torch_dict = torch_model.state_dict() model.load_state_dict(torch_dict, strict=False) @@ -86,7 +88,8 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["gpt2", "bert"]) -def exam_state_dict_shard(placement_config, model_name: str): +@parameterize("master_weights", [False, True]) +def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -95,7 +98,7 @@ def exam_state_dict_shard(placement_config, model_name: str): model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - model = GeminiDDP(model, config_dict, **placement_config) + model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights) model.train() zero_dict = model.state_dict(only_rank_0=False)