From 1aaa45370665ba724a1c815d0dd916b0346e2472 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Thu, 28 Mar 2024 15:02:32 +0800
Subject: [PATCH] perf: use async copy to accelerate memcpy

---
 colossalai/zero/low_level/low_level_optim.py | 88 +++++++++++++++-----
 1 file changed, 66 insertions(+), 22 deletions(-)

diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 378bbd2fc..9decf6ffd 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -21,7 +21,14 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 
-from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
+from ._utils import (
+    DataPrefetcher,
+    calculate_global_norm_from_list,
+    flatten,
+    has_inf_or_nan,
+    release_param_grad,
+    sync_tensor,
+)
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
 
 
@@ -437,10 +444,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     grad = grads[grad_index]
-                    # no need to copy fp32 grad if master_weights is False
-                    if self._master_weights:
-                        grad = grad.to(splited_param.dtype).to(splited_param.device)
-                    splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
 
@@ -458,27 +461,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
-        # update the parameters
-        self.optim.step()
-
-        # release the grad
-        grad_partition_groups = []
-        for group_id in range(self.num_param_groups):
-            release_param_grad(self._master_param_groups_of_current_rank[group_id])
-
-        # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
-            master_working_param = self.optim.param_groups[group_id]["params"]
-            for idx, splited_param in enumerate(master_working_param):
-                working_param = real_working_params[group_id][idx]
-                all_splited_param = [
-                    torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
-                ]
-                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+
+            def load_grad(num: int):
+                """copy grads to the same device and dtype as the master weights"""
+                for i in range(num):
+                    grad = grad_partition_groups.pop(0)
+                    # no need to copy fp32 grad if master_weights is False
+                    if self._master_weights:
+                        grad = grad.to(real_master_params[group_id][i].dtype).to(
+                            real_master_params[group_id][i].device, non_blocking=True
+                        )
+                    yield grad
+
+            def load_param(num: int):
+                """copy params back to the accelerator"""
+                for i in range(num):
+                    splited_param = real_master_params[group_id][i].to(device, non_blocking=True).to(self._dtype)
+                    yield splited_param
+
+            """
+            grad (device) --> grad (host or device) --> optim.step() --> param (host or device) --> param (device)
+            """
+            grad_pre_fetcher, param_pre_fetcher = None, None
+            for idx in range(len(real_master_params[group_id]) + 1):
+                is_first_step = idx == 0
+                is_last_step = idx == len(real_master_params[group_id])
+
+                if not is_last_step:
+                    # update the parameters
+                    if grad_pre_fetcher is None:
+                        grad_pre_fetcher = DataPrefetcher(load_grad(len(real_master_params[group_id])))
+
+                    real_master_params[group_id][idx].grad = grad_pre_fetcher.next()
+                    # HACK: torch optim would skip tensor whose grad is None
+                    self.optim.step()
+                    real_master_params[group_id][idx].grad = None
+
+                if not is_first_step:
+                    # update working partition updated by the current rank
+                    if param_pre_fetcher is None:
+                        param_pre_fetcher = DataPrefetcher(load_param(len(real_master_params[group_id])))
+
+                    working_param = real_working_params[group_id][idx - 1]
+                    splited_param = param_pre_fetcher.next()
+
+                    all_splited_param = [
+                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
+                        for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(all_splited_param, splited_param, group=self.dp_pg)
+
+                    working_param.data.copy_(
+                        flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
+                    )
+
+            # release the grad
+            release_param_grad(self._master_param_groups_of_current_rank[group_id])
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
+        assert len(grad_partition_groups) == 0, "grad_partition_groups should be empty after step()"
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
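
Note on DataPrefetcher: the patch imports DataPrefetcher from colossalai/zero/low_level/_utils.py but its definition is not part of this diff. The sketch below is a hypothetical, minimal generator-based prefetcher with the same constructor and next() interface used in step(); it stays one item ahead of the consumer so that the non_blocking copies issued inside load_grad/load_param can overlap with optim.step() and the all_gather. The real implementation in _utils.py may differ (for example, it may use a dedicated copy stream).

# Hypothetical sketch only -- not the code shipped in _utils.py.
from typing import Any, Generator, Optional


class DataPrefetcher:
    """Wrap a generator and keep one element pre-fetched.

    Advancing the generator early triggers the .to(..., non_blocking=True)
    copy for the *next* grad/param while the caller is still busy with the
    current one, which is what hides the memcpy latency in step().
    """

    def __init__(self, generator: Generator[Any, None, None]) -> None:
        self._generator = generator
        self._next_item: Optional[Any] = None
        self._preload()

    def _preload(self) -> None:
        # Pull the next item now so its async copy is already in flight
        # when the caller asks for it.
        try:
            self._next_item = next(self._generator)
        except StopIteration:
            self._next_item = None

    def next(self) -> Any:
        item = self._next_item
        self._preload()
        return item

One caveat: for host-to-device transfers, non_blocking=True only overlaps with compute when the source tensor is in pinned memory; otherwise PyTorch falls back to a synchronous copy.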