From e27645376de9e65fd9e79e7c5bab240dd47ebef5 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Mon, 27 Jun 2022 09:53:57 +0800
Subject: [PATCH] [hotfix]different overflow status lead to communication
 stuck. (#1175)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]fix some bugs caused by refactored schedule.

* [hotfix]different overflow status lead to communication stuck.
---
 colossalai/amp/naive_amp/_fp16_optimizer.py | 17 +++++++++--------
 colossalai/communication/p2p.py             | 16 +++++++++++-----
 colossalai/utils/common.py                  | 18 +++++++++++++++---
 3 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/colossalai/amp/naive_amp/_fp16_optimizer.py b/colossalai/amp/naive_amp/_fp16_optimizer.py
index e47544dce..58d9e3df1 100644
--- a/colossalai/amp/naive_amp/_fp16_optimizer.py
+++ b/colossalai/amp/naive_amp/_fp16_optimizer.py
@@ -258,24 +258,25 @@ class FP16Optimizer(Optimizer):

         overflow = self._check_overflow()
         self._grad_scaler.update(overflow)
-
         if overflow:
             self.zero_grad()
-            return False, None

         # Clip the main gradients.
         grad_norm = None
         if self._clip_grad_max_norm > 0.0:
             grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

-        # Step the optimizer.
-        self._optimizer.step()
+        if not overflow:
+            # Step the optimizer.
+            self._optimizer.step()

-        # Update params from main params.
-        self._update_fp16_param_from_fp32_param()
+            # Update params from main params.
+            self._update_fp16_param_from_fp32_param()

-        # Successful update.
-        return True, grad_norm
+            # Successful update.
+            return True, grad_norm
+        else:
+            return False, None

     def backward(self, loss):
         """Execute backward pass.
diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py
index 6722860e7..3301b750f 100644
--- a/colossalai/communication/p2p.py
+++ b/colossalai/communication/p2p.py
@@ -57,10 +57,14 @@ def process_object_to_send(object_send, scatter_gather_tensors):
         if send_split:
             object_send = split_tensor_into_1d_equal_chunks(object_send)
         return object_send
+
+    object_send_list = []
     for tensor_send in object_send:
         send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
         if send_split:
-            tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
+            object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
+    object_send = tuple(object_send_list)
+
     return object_send


@@ -161,15 +165,17 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
         if isinstance(tensor_recv_prev, torch.Tensor):
             tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
         else:
-            for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
-                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
+            for index in range(len(tensor_recv_prev)):
+                tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
+                    recv_prev_shape[index]).requires_grad_()

     if recv_next and recv_next_split:
         if isinstance(tensor_recv_next, torch.Tensor):
             tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
         else:
-            for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
-                tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
+            for index in range(len(tensor_recv_next)):
+                tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
+                    recv_next_shape[index]).requires_grad_()

     return tensor_recv_prev, tensor_recv_next

diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 988dfc90b..6114ab11a 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -151,6 +151,14 @@ def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.T
     return norm


+def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
+    if isinstance(norm, float):
+        norm = torch.Tensor([norm])
+        if move_to_cuda:
+            norm = norm.to(torch.cuda.current_device())
+    return norm
+
+
 # ======== Gradient Clipping =========


@@ -192,14 +200,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             params.append(param)

     if len(params) == 0:
-        return 0.0
+        enable_cuda_kernels = False
+    else:
+        enable_cuda_kernels = params[0].grad.device.type == 'cuda'
     # Norm parameters.
     max_norm = float(max_norm)
     norm_type = float(norm_type)

     # Parameters can be on CPU or CUDA
     # If parameters are on CPU, disable CUDA kernerls
-    enable_cuda_kernels = params[0].grad.device.type == 'cuda'

     # Calculate norm.
     if norm_type == inf:
@@ -238,7 +247,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
             no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
             zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
-
+        # If norm is type of float, then we convert them into torch.Tensor.
+        tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
+        no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
+        zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
         # If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
         if not enable_cuda_kernels:
             tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
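
Context for reviewers (not part of the patch): the commit title refers to gradient-norm clipping being a collective operation. In the old FP16Optimizer.step, a rank whose local overflow flag was set returned before clip_grad_norm, while peers with no overflow still entered the cross-rank reduction and blocked forever. The sketch below illustrates the pattern the patch adopts, using plain torch.distributed; the names step_with_consistent_collectives and local_overflow are invented for the example and are not ColossalAI APIs.

# Illustrative sketch only, not part of the commit above. Assumes PyTorch
# with torch.distributed available; names here are made up for the example.
import torch
import torch.distributed as dist


def step_with_consistent_collectives(optimizer, params, local_overflow: bool, max_norm: float = 1.0):
    # Do the collective work first, on every rank, regardless of the local
    # overflow flag. The all-reduce below stands in for the cross-rank
    # reduction that a model-parallel clip_grad_norm performs internally.
    grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
    if dist.is_initialized():
        dist.all_reduce(grad_norm, op=dist.ReduceOp.MAX)

    # Only purely local work is gated on the overflow flag, so a rank that
    # skips it never leaves its peers waiting inside a collective call.
    if not local_overflow:
        optimizer.step()
        return True, grad_norm
    return False, None

In the patch itself the same idea appears as keeping clip_grad_norm ahead of the overflow branch in FP16Optimizer.step and guarding only _optimizer.step() and _update_fp16_param_from_fp32_param() with the "if not overflow:" check.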