diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 5bc662a61..333a3f224 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -20,7 +20,12 @@ class ChunkManager:
         init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
     """
 
-    def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
+    def __init__(
+        self,
+        chunk_configuration,
+        init_device: Optional[torch.device] = None,
+        reuse_fp16_chunk: bool = True,
+    ) -> None:
         self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
         self.kwargs_config = chunk_configuration
@@ -33,6 +38,10 @@ class ChunkManager:
         self.accessed_chunks: Set[Chunk] = set()
         self.accessed_mem: int = 0
         self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
+        self.reuse_fp16_chunk = reuse_fp16_chunk
+        # Whether model is accumulating gradients,
+        self.accumulating_grads = False
+        self.overflow_counter = 0
 
     def register_tensor(
         self,
diff --git a/colossalai/zero/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py
index 7a2ea3606..049c5c102 100644
--- a/colossalai/zero/gemini/chunk/utils.py
+++ b/colossalai/zero/gemini/chunk/utils.py
@@ -19,6 +19,7 @@ def init_chunk_manager(
     model: nn.Module,
     init_device: Optional[torch.device] = None,
     hidden_dim: Optional[int] = None,
+    reuse_fp16_chunk: bool = True,
     verbose: bool = False,
     **kwargs,
 ) -> ChunkManager:
@@ -50,5 +51,9 @@ def init_chunk_manager(
     )
     dist.barrier()
 
-    chunk_manager = ChunkManager(config_dict, init_device)
+    chunk_manager = ChunkManager(
+        config_dict,
+        init_device,
+        reuse_fp16_chunk=reuse_fp16_chunk,
+    )
     return chunk_manager
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index b25de1d68..c1029097a 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -98,8 +98,14 @@ class GeminiDDP(ModelWrapper):
         verbose: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
+        reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
+        self.enable_gradient_accumulation = enable_gradient_accumulation
         if chunk_config_dict is not None:
-            self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
+            self.chunk_manager = ChunkManager(
+                chunk_config_dict,
+                chunk_init_device,
+                reuse_fp16_chunk=reuse_fp16_chunk,
+            )
         else:
             # some ugly hotfix for the compatibility with Lightning
             if search_range_m is None:
@@ -112,6 +118,7 @@ class GeminiDDP(ModelWrapper):
                 min_chunk_size_m=min_chunk_size_m,
                 strict_ddp_flag=strict_ddp_mode,
                 process_group=zero_group,
+                reuse_fp16_chunk=reuse_fp16_chunk,
                 verbose=verbose,
             )
         self.gemini_manager = GeminiManager(
@@ -128,7 +135,6 @@ class GeminiDDP(ModelWrapper):
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
-        self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
@@ -137,14 +143,8 @@ class GeminiDDP(ModelWrapper):
         self.zero_group = zero_group or _get_default_group()
         self.extra_dp_group = extra_dp_group
 
-        self.reuse_fp16_chunk = master_weights
         self.master_weights = master_weights
 
-        self.enable_gradient_accumulation = enable_gradient_accumulation
-        if self.enable_gradient_accumulation:
-            self.reuse_fp16_chunk = False
-        self.accumulating_grads = False  # Whether model is accumulating gradients
-
         self._logger = get_dist_logger()
 
         if self.gemini_manager._premade_memstats_:
@@ -178,7 +178,29 @@ class GeminiDDP(ModelWrapper):
             if is_ddp_ignored(p):
                 continue
             if p.requires_grad:
-                p.register_hook(partial(self.grad_handle, p))
+                p._grad_handle = p.register_hook(
+                    partial(
+                        GeminiDDP.grad_handle,
+                        chunk_manager=self.chunk_manager,
+                        param2name=self.param2name,
+                        grads_device=self.grads_device,
+                        master_weights=self.master_weights,
+                        enable_gradient_accumulation=self.enable_gradient_accumulation,
+                        p=p,
+                    )
+                )
+
+    def remove_hooks(self):
+        for p in self.module.parameters():
+            if is_ddp_ignored(p):
+                continue
+            if p.requires_grad:
+                assert hasattr(p, "_grad_handle")
+                p._grad_handle.remove()
+                delattr(p, "_grad_handle")
+
+    def __del__(self):
+        self.remove_hooks()
 
     def parameters(self, recurse: bool = True):
         return self.module.parameters(recurse)
@@ -324,8 +346,8 @@ class GeminiDDP(ModelWrapper):
                 f"{error_str}",
             )
         self._setup_grads_ptr()
-        if self.enable_gradient_accumulation and not self.accumulating_grads:
-            self.accumulating_grads = True  # Turn on the state of gradient accumulation.
+        if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
+            self.chunk_manager.accumulating_grads = True  # Turn on the state of gradient accumulation.
         self._logger.debug(
             f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
         )
@@ -340,25 +362,34 @@ class GeminiDDP(ModelWrapper):
     def backward_by_grad(self, tensor, grad):
         raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
 
-    def grad_handle(self, p, grad):
+    @staticmethod
+    def grad_handle(
+        grad,
+        chunk_manager: ChunkManager,
+        param2name: Dict,
+        grads_device: Dict,
+        master_weights: bool,
+        enable_gradient_accumulation: bool,
+        p: nn.Parameter,
+    ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
         free_storage(empty_grad)
         with torch._C.DisableTorchFunction():
-            chunk = self.chunk_manager.get_chunk(p)
+            chunk = chunk_manager.get_chunk(p)
             if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
                 raise RuntimeError(
-                    f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+                    f"Parameter `{param2name[p]}` failed at the gradient reduction. "
                     "Some unsupported torch function is operated upon this parameter."
                 )
             grad_chunk = chunk
-            if not self.reuse_fp16_chunk:
-                if not self.accumulating_grads:
-                    grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
+            if not chunk_manager.reuse_fp16_chunk:
+                if not chunk_manager.accumulating_grads:
+                    grad_chunk = chunk_manager.init_grad_chunk(chunk)
                 else:
                     assert chunk.grad_chunk is not None
-                    if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
-                        grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
+                    if chunk.grad_chunk not in chunk_manager.accessed_chunks:
+                        grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
                     else:
                         grad_chunk = chunk.grad_chunk
                         chunk.grad_chunk.l2_norm = None
@@ -371,33 +402,33 @@ class GeminiDDP(ModelWrapper):
                 chunk.tensor_trans_state(p, TensorState.HOLD)
 
             grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
-            if not self.accumulating_grads:
-                grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
+            if not chunk_manager.accumulating_grads:
+                grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-            reduced = self.chunk_manager.reduce_chunk(grad_chunk)
+            reduced = chunk_manager.reduce_chunk(grad_chunk)
             if reduced:
-                if not self.reuse_fp16_chunk:
+                if not chunk_manager.reuse_fp16_chunk:
                     if chunk.keep_gathered:
-                        self.chunk_manager.fake_release_chunk(chunk)
+                        chunk_manager.fake_release_chunk(chunk)
                     else:
-                        self.chunk_manager.release_chunk(chunk)
+                        chunk_manager.release_chunk(chunk)
                 if grad_chunk.is_gathered:
                     grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-                    if self.extra_dp_group is not None:
+                    if chunk.extra_dp_group is not None:
                         grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
                 else:
                     grad_chunk.cuda_shard.div_(chunk.pg_size)
-                    if self.extra_dp_group is not None:
+                    if chunk.extra_dp_group is not None:
                         grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
                 # check overflow elements
-                self.overflow_counter += grad_chunk.has_inf_or_nan
+                chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
                 # record l2 norm for gradient clipping. flag is bound to fp16 chunk
                 if chunk.l2_norm_flag:
                     grad_chunk.set_l2_norm()
-                self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
-                if not (self.master_weights) or (self.enable_gradient_accumulation):
-                    self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
+                chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                if not (master_weights) or (enable_gradient_accumulation):
+                    chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
         return empty_grad
 
     def zero_grad(self, set_to_none: bool = False) -> None:
@@ -513,11 +544,11 @@ class GeminiDDP(ModelWrapper):
 
         # get copies of fp32 parameters in CPU
         # as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
-        params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
+        params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
         param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
         # get the mapping between copies and fp16 parameters
         p_mapping = dict()
-        if self.reuse_fp16_chunk:
+        if self.chunk_manager.reuse_fp16_chunk:
             for p, fp32_p in zip(self.fp16_params, self.fp32_params):
                 name = self.param2name[p]
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
@@ -713,7 +744,7 @@ class GeminiDDP(ModelWrapper):
                 name = self.param2name[p]
                 fp32_to_name[fp32_p] = name
 
-        params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
+        params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
         chunk_list = self.chunk_manager.get_chunks(params_to_load)
         for chunk in chunk_list:
             temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
@@ -728,7 +759,9 @@ class GeminiDDP(ModelWrapper):
 
             shard_fn = tensor.shard_fn
             gather_fn = tensor.gather_fn
-            parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
+            parameter_name = (
+                fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
+            )
             parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
             load(
                 parameter_name,
@@ -900,7 +933,7 @@ class GeminiDDP(ModelWrapper):
                 gathered_param = param if keep_vars else param.detach()
             else:
                 # as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
-                param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
+                param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
                 if param_to_save not in gathered_param_buffer:
                     chunk = self.chunk_manager.get_chunk(param_to_save)
                     gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index ae02fe297..18918eabc 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module
 
     def check_local_overflow(self) -> bool:
-        return self.module.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter > 0
 
     def pre_zero_grad(self) -> None:
-        self.module.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter = 0
 
 
 class GeminiOptimizer(OptimizerWrapper):
@@ -202,7 +202,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 chunk16 = self.param_to_chunk16[fake_param]
                 begin, end = self.param_to_range[fake_param]
 
-                grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
+                grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
                 fake_param.data = grad_chunk16.payload[begin:end]
                 fake_param.grad = fake_param.data
 
@@ -221,14 +221,14 @@ class GeminiOptimizer(OptimizerWrapper):
 
     def _clear_global_norm(self) -> None:
        for c16 in self.chunk16_set:
-            grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
+            grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
            grad_chunk.l2_norm = None
 
    def _calc_global_norm(self) -> float:
        norm_sqr: float = 0.0
        group_to_norm = dict()
        for c16 in self.chunk16_set:
-            grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
+            grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
            assert grad_chunk.l2_norm is not None
 
            if grad_chunk.is_gathered:
@@ -275,7 +275,7 @@ class GeminiOptimizer(OptimizerWrapper):
             self._logger.info(f"Found overflow. Skip step")
             self._clear_global_norm()  # clear recorded norm
             self.zero_grad()  # reset all gradients
-            if self.module.reuse_fp16_chunk:
+            if self.module.chunk_manager.reuse_fp16_chunk:
                 self._update_fp16_params()
             return
 
@@ -288,7 +288,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.zero_grad()
         if self.module.master_weights:
             self._update_fp16_params()
-        self.module.accumulating_grads = False
+        self.module.chunk_manager.accumulating_grads = False
         return ret
 
     def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index d9084fd5a..570a0aa42 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -26,7 +26,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
     chunk_manager = model.chunk_manager
     param_list = [p for p in model.parameters()]
     chunk_list = chunk_manager.get_chunks(param_list)
-    if not model.reuse_fp16_chunk:
+    if not model.chunk_manager.reuse_fp16_chunk:
         chunk_list = [chunk.grad_chunk for chunk in chunk_list]
     for chunk in chunk_list:
         chunk_manager.access_chunk(chunk)
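
Illustrative sketch, not part of the patch: after this change the bookkeeping state (`reuse_fp16_chunk`, `accumulating_grads`, `overflow_counter`) lives on `ChunkManager` rather than on `GeminiDDP`, so callers read it through `model.chunk_manager`, as the updated test does. The helper name below is hypothetical; only the attribute names and their meanings come from the diff.

```python
# Hypothetical helper: collect the state this diff relocates onto ChunkManager.
def gemini_grad_state(model) -> dict:
    cm = model.chunk_manager  # model is a GeminiDDP wrapping a ChunkManager
    return {
        # True only when master weights are kept and gradient accumulation is off
        "reuse_fp16_chunk": cm.reuse_fp16_chunk,
        # flipped to True after a backward pass when gradient accumulation is enabled
        "accumulating_grads": cm.accumulating_grads,
        # incremented by grad_handle for each reduced grad chunk containing inf/NaN
        "overflow_counter": cm.overflow_counter,
    }
```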