diff --git a/colossalai/amp/naive_amp/naive_amp.py b/colossalai/amp/naive_amp/naive_amp.py
index 8fc1b109b..722e468ce 100644
--- a/colossalai/amp/naive_amp/naive_amp.py
+++ b/colossalai/amp/naive_amp/naive_amp.py
@@ -32,7 +32,7 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
         loss.backward()
 
     def step(self):
-        self.optim.step()
+        return self.optim.step()
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         pass
diff --git a/colossalai/context/parallel_mode.py b/colossalai/context/parallel_mode.py
index d50448513..34c3ad475 100644
--- a/colossalai/context/parallel_mode.py
+++ b/colossalai/context/parallel_mode.py
@@ -26,6 +26,7 @@ class ParallelMode(Enum):
 
     # sequence parallel
     SEQUENCE = 'sequence'
+    SEQUENCE_DP = 'sequence_dp'
 
     # 1D Parallel
     PARALLEL_1D = '1d'
diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py
index 0d0c41e2d..413c92036 100644
--- a/colossalai/context/process_group_initializer/initializer_sequence.py
+++ b/colossalai/context/process_group_initializer/initializer_sequence.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import torch.distributed as dist
 
 from colossalai.registry import DIST_GROUP_INITIALIZER
 from .initializer_tensor import Initializer_Tensor
@@ -7,6 +8,43 @@ from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
 
 
+@DIST_GROUP_INITIALIZER.register_module
+class Initializer_Sequence_DP(ProcessGroupInitializer):
+    '''A ProcessGroupInitializer for sequence parallelism all-reduce.
+
+    In sequence parallelism, each GPU holds a full copy of the model weights,
+    so gradient all-reduce occurs across all processes in the same pipeline stage.
+
+    '''
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dp_size = self.world_size // self.pipeline_parallel_size
+        self.num_group = self.pipeline_parallel_size
+
+    def init_dist_group(self):
+        '''Initialize sequence parallel process groups used for gradient all-reduce.
+        :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        :rtype: tuple
+        '''
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        group_world_size = None
+        mode = ParallelMode.SEQUENCE_DP
+
+        for i in range(self.num_group):
+            ranks = [i * self.dp_size + j for j in range(self.dp_size)]
+            group = dist.new_group(ranks)
+
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                ranks_in_group = ranks
+        return local_rank, group_world_size, process_group, ranks_in_group, mode
+
+
 @DIST_GROUP_INITIALIZER.register_module
 class Initializer_Sequence(ProcessGroupInitializer):
     '''A ProcessGroupInitializer for sequence parallelism.
@@ -15,13 +53,27 @@ class Initializer_Sequence(ProcessGroupInitializer):
 
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # reuse tensor parallel code
-        self._initializer = Initializer_Tensor(*args, **kwargs)
+        # reuse tensor parallel initializer code
+        self._sequence_initializer = Initializer_Tensor(*args, **kwargs)
+        self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs)
 
     def init_dist_group(self):
-        local_rank, group_world_size, process_group, ranks_in_group, mode = self._initializer.init_dist_group()
+        '''Initialize sequence parallel process groups and assign local ranks and groups to each GPU.
+        Sequence parallelism requires 2 process groups. The first is for the model forward pass, where several processes
+        exchange partial query, key and value embeddings to compute self-attention values. The second is for
+        all-reduce to synchronize the model parameters.
+
+        :return: Sequence parallelism's process group information
+        :rtype: list of tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        '''
+
+        parallel_setting = []
+
+        local_rank, group_world_size, process_group, ranks_in_group, mode = self._sequence_initializer.init_dist_group()
         # change mode to sequence
         mode = ParallelMode.SEQUENCE
-
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+
+        parallel_setting.append((local_rank, group_world_size, process_group, ranks_in_group, mode))
+        parallel_setting.append(self._sequence_dp_initializer.init_dist_group())
+        return parallel_setting
diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index 985c9d422..4ecc25cb9 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -77,7 +77,7 @@ class Engine:
         """
         self._all_reduce_gradients()
         self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
-        self.optimizer.step()
+        return self.optimizer.step()
 
     def backward(self, loss: Tensor):
         """Start backward propagation given the loss value computed by a loss function
diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py
index 863bb6b5b..836f1f72b 100644
--- a/colossalai/engine/gradient_handler/__init__.py
+++ b/colossalai/engine/gradient_handler/__init__.py
@@ -1,9 +1,10 @@
 from ._base_gradient_handler import BaseGradientHandler
 from ._data_parallel_gradient_handler import DataParallelGradientHandler
 from ._zero_gradient_handler import ZeROGradientHandler
+from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
 from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
 from ._moe_gradient_handler import MoeGradientHandler
 
 __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
            'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
-           'MoeGradientHandler']
+           'MoeGradientHandler', 'SequenceParallelGradientHandler']
diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
new file mode 100644
index 000000000..69563acba
--- /dev/null
+++ b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+import torch
+import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+from colossalai.core import global_context as gpc
+from colossalai.registry import GRADIENT_HANDLER
+from ._base_gradient_handler import BaseGradientHandler
+from ...context.parallel_mode import ParallelMode
+
+
+@GRADIENT_HANDLER.register_module
+class SequenceParallelGradientHandler(BaseGradientHandler):
+    """A helper class to handle gradient all-reduce in sequence parallelism.
+    An all-reduce collective communication is performed by
+    :func:`handle_gradient` across the sequence data parallel (SEQUENCE_DP) group.
+    For better performance, it bucketizes the gradients of all parameters of
+    the same type to improve the efficiency of communication.
+    """
+
+    def handle_gradient(self):
+        """Run an all-reduce on the gradients in the sequence data parallel (SEQUENCE_DP) group.
+        """
+
+        # bucketize and all-reduce
+        buckets = {}
+
+        # Pack the buckets.
+        for param in self._model.parameters():
+            if param.requires_grad and param.grad is not None:
+                tp = param.data.type()
+                if tp not in buckets:
+                    buckets[tp] = []
+                buckets[tp].append(param)
+
+        # For each bucket, all-reduce and copy all-reduced grads.
+        for tp in buckets:
+            bucket = buckets[tp]
+            grads = [param.grad.data for param in bucket]
+            coalesced = _flatten_dense_tensors(grads)
+
+            coalesced /= gpc.get_world_size(ParallelMode.SEQUENCE_DP)
+
+            dist.all_reduce(
+                coalesced, group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+
+            for buf, synced in zip(grads, _unflatten_dense_tensors(
+                    coalesced, grads)):
+                buf.copy_(synced)
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 42a585e08..cbd88da79 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -222,7 +222,6 @@ class PipelineSchedule(BaseSchedule):
 
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
-        self.load_batch(data_iter)
 
         num_warmup_microbatches = \
             (gpc.get_world_size(ParallelMode.PIPELINE) -
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index ae75ee996..74da408fa 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -17,7 +17,7 @@ from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import (accumulate_gradient, get_current_device,
-                              sync_model_param_in_dp, is_using_ddp, is_using_pp)
+                              sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
 from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
 from colossalai.builder.builder import build_gradient_handler
 from torch.optim.optimizer import Optimizer
@@ -187,7 +187,7 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       backend: str = 'nccl',
                       seed: int = 1024,
                       verbose: bool = True):
-    '''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size 
+    '''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch
 
     :param config: config file or config file path are both acceptable
@@ -270,12 +270,15 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
         model.to(get_current_device())
     use_zero3 = hasattr(gpc.config, 'zero') and gpc.config.zero.level == 3
     if not moe_env.is_initialized() and not use_zero3:
-        sync_model_param_in_dp(model)
+        if is_using_sequence():
+            sync_model_param(model, ParallelMode.SEQUENCE_DP)
+        elif is_using_ddp():
+            sync_model_param(model, ParallelMode.DATA)
     else:
-        print(
-            "Warning: The parameters of models is not automatically synchronized.\n"
+        logger.warning(
+            "The parameters of the model are not automatically synchronized.\n"
             "Please make sure that all parameters are the same in data parallel group.",
-            flush=True)
+            ranks=[0])
 
     # check amp and zero
     fp16_cfg = gpc.config.get('fp16', None)
@@ -339,11 +342,16 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
                 "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
                 "added even though not specified in the configuration",
                 ranks=[0])
+    elif is_using_sequence():
+        model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP))
+        if verbose:
+            logger.info(
+                'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
     elif is_using_ddp() and not is_using_pp() and amp_mode != AMP_TYPE.NAIVE:
         model = DDP(model, process_group=gpc.get_group(ParallelMode.DATA))
         if verbose:
             logger.info(
-                'Model is using torch.nn.parallel.DistributedDataParallel', ranks=[0])
+                'Model is using torch.nn.parallel.DistributedDataParallel for Data Parallelism', ranks=[0])
     elif is_using_ddp():
         gradient_handler_cfg = [dict(type='DataParallelGradientHandler')]
         if verbose:
diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py
index 4e1d486d9..a45a3e7ae 100644
--- a/colossalai/kernel/cuda_native/layer_norm.py
+++ b/colossalai/kernel/cuda_native/layer_norm.py
@@ -6,6 +6,7 @@ import numbers
 import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
+from torch.cuda.amp import custom_fwd, custom_bwd
 import importlib
 
 global colossal_layer_norm_cuda
@@ -15,6 +16,7 @@ colossal_layer_norm_cuda = None
 class FusedLayerNormAffineFunction(torch.autograd.Function):
 
     @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
 
         ctx.normalized_shape = normalized_shape
@@ -29,6 +31,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         return output
 
     @staticmethod
+    @custom_bwd
     def backward(ctx, grad_output):
 
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
@@ -71,3 +74,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
 
         return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias,
                                                   self.normalized_shape, self.eps)
+
+    def __repr__(self):
+        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index c21789726..d95905897 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -6,7 +6,7 @@ JIT_OPTIONS_SET = False
 
 def set_jit_fusion_options():
     """Set PyTorch JIT layer fusion options.
""" - # LSG: the latest pytorch and CUDA versions may not support + # LSG: the latest pytorch and CUDA versions may not support # the following jit settings global JIT_OPTIONS_SET if JIT_OPTIONS_SET == False: diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/nn/layer/parallel_sequence/_operation.py index d5f65d5d8..119302a09 100644 --- a/colossalai/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/nn/layer/parallel_sequence/_operation.py @@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range from colossalai.utils import get_current_device +from torch.cuda.amp import custom_bwd, custom_fwd class RingQK(torch.autograd.Function): @@ -17,6 +18,7 @@ class RingQK(torch.autograd.Function): """ @staticmethod + @custom_fwd def forward(ctx, sub_q, sub_k, @@ -54,6 +56,7 @@ class RingQK(torch.autograd.Function): return attention_score @staticmethod + @custom_bwd def backward(ctx, grad_output): sub_q, sub_k, = ctx.saved_tensors local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) @@ -64,6 +67,7 @@ class RingQK(torch.autograd.Function): grad_output.transpose(2, 1), sub_q ) + dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE)) grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length] grad_k /= local_world_size @@ -94,6 +98,7 @@ class RingAV(torch.autograd.Function): """ @staticmethod + @custom_fwd def forward(ctx, attention_score, sub_v, @@ -131,6 +136,7 @@ class RingAV(torch.autograd.Function): return sub_attention_result @staticmethod + @custom_bwd def backward(ctx, grad_output): local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE) diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py index 132fc3dcc..3e87f10f0 100644 --- a/colossalai/nn/layer/parallel_sequence/layers.py +++ b/colossalai/nn/layer/parallel_sequence/layers.py @@ -2,15 +2,20 @@ # -*- encoding: utf-8 -*- import math +import colossalai import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn import Parameter from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV from colossalai.registry import LAYERS +from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.kernel import FusedScaleMaskSoftmax +from colossalai.context import seed @LAYERS.register_module @@ -31,136 +36,144 @@ class TransformerSelfAttentionRing(nn.Module): def __init__(self, hidden_size, - kv_channels, num_attention_heads, attention_dropout, + attention_mask_func, + layer_number, + apply_query_key_layer_scaling: bool = False, + convert_fp16_to_fp32_in_softmax: bool = False, + attn_mask_type=AttnMaskType.padding, + masked_softmax_fusion=True, + fp16=False, + bf16=False ): super().__init__() - + self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_mask_func = attention_mask_func + self.layer_number = layer_number self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads + self.attn_mask_type = attn_mask_type + assert self.layer_number > 0 + self.attention_dropout = attention_dropout - 
projection_size = kv_channels * num_attention_heads - self.hidden_size_per_attention_head = projection_size // num_attention_heads + if self.apply_query_key_layer_scaling: + self.convert_fp16_to_fp32_in_softmax = True + + assert self.hidden_size % self.num_attention_heads == 0, \ + 'hidden size is not divisible by the number of attention heads' + + self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE) # Strided linear layer. - self.query_key_value = nn.Linear( + self.query_key_value = _Linear( hidden_size, - 3 * projection_size, + 3 * self.hidden_size, ) - # coeff = None + self.coeff = None self.norm_factor = math.sqrt(self.hidden_size) - # TODO: add apply_query_key_layer_scaling when we have the kernel module - # if self.apply_query_key_layer_scaling: - # coeff = self.layer_number - # self.norm_factor *= coeff + if self.apply_query_key_layer_scaling: + self.coeff = layer_number + self.norm_factor *= self.coeff - # TODO: add fused scale mask softmax kernel when we have the kernel module - # self.scale_mask_softmax = FusedScaleMaskSoftmax( - # self.fp16, self.bf16, - # self.attn_mask_type, - # masked_softmax_fusion, - # attention_mask_func, - # self.attention_softmax_in_fp32, - # coeff) + self.scale_mask_softmax = FusedScaleMaskSoftmax( + fp16, bf16, + self.attn_mask_type, + masked_softmax_fusion, + self.attention_mask_func, + self.convert_fp16_to_fp32_in_softmax, + self.coeff) self.attention_dropout = nn.Dropout(attention_dropout) # Output. - self.dense = nn.Linear( - projection_size, - hidden_size, - bias=True) + self.dense = _Linear(hidden_size, + hidden_size, + bias=True, + skip_bias_add=True) def forward(self, hidden_states, attention_mask): - # hidden_states: [sq, b, h] - + # hidden_states: [sub_seq_len, batch_size, hidden_size] + # attention_mask: [batch_size, 1, sub_seq_len, seq_len] sub_seq_length, batch_size, hidden_size = hidden_states.size() # ===================== # Query, Key, and Value # ===================== - # Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)] + # Attention heads shape change: + # [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)] mixed_x_layer = self.query_key_value(hidden_states) - # [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn] + # [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size] new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads, 3 * self.hidden_size_per_attention_head) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # split into query, key and value last_dim = mixed_x_layer.dim() - 1 - last_dim_value = mixed_x_layer.size()[-1] + last_dim_value = mixed_x_layer.size(-1) assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \ 'cannot be divided into query, key and value' partition_size = last_dim_value // 3 (query_layer, key_layer, value_layer) = torch.split( mixed_x_layer, partition_size, dim=last_dim) - # =================================== - # Raw attention scores. 
[b, num_heads, s, s] - # =================================== - - # [b, num_heads, sq, sk] + # attention scores: [batch_size, num_heads, sub_seq_len, seq_len] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0) * self.world_size) - # [sq, b, num_heads, hn] -> [sq, b * num_heads, hn] + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, num_heads, hn] -> [sk, b * num_heads, hn] + # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size] key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1) - # [b, sq, sk] + # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len] attention_scores = RingQK.apply( - # [b * num_heads, sq, hn] - query_layer.transpose(0, 1).contiguous(), - key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn], + query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size] + key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size], batch_size, self.num_attention_heads, sub_seq_length ) + attention_scores /= self.norm_factor - # change view to [b, num_heads, sq, sk] + # change view to [batch_size, num_heads, sub_seq_len, seq_len] attention_scores = attention_scores.view(*output_size) - attention_scores = attention_scores.unsqueeze(1) - - attention_scores = attention_scores + attention_mask - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.squeeze(1) + # change shape to [batch_size, num_heads, sub_seq_len, seq_len] + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
-        # with mpu.get_cuda_rng_tracker().fork():
-        # TODO: check if a rng tracker is needed
-        attention_probs = self.attention_dropout(attention_probs)
+        with seed(ParallelMode.TENSOR):
+            attention_probs = self.attention_dropout(attention_probs)
 
-        # context layer shape: [b, num_heads, sq, hn]
+        # context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
         output_size = (value_layer.size(1),
                        value_layer.size(2),
                        query_layer.size(0),
                        value_layer.size(3))
-        #
-        # # change view [sk, b * num_heads, hn]
+
+        # change view [sub_seq_len, batch_size * num_heads, head_size]
         value_layer = value_layer.contiguous().view(value_layer.size(0),
                                                     output_size[0] * output_size[1],
                                                     -1)
 
-        # # change view [b * num_heads, sq, sk]
+        # change view [batch_size * num_heads, sub_seq_len, seq_len]
         attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
                                                attention_probs.size(2),
                                                attention_probs.size(3))
 
-        # matmul: [b*num_heads, sq, hn]
-        # context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+        # matmul: [batch_size * num_heads, sub_seq_len, head_size]
         context_layer = RingAV.apply(
             attention_probs,
             value_layer.transpose(0, 1).contiguous(),
@@ -170,19 +183,76 @@ class TransformerSelfAttentionRing(nn.Module):
             sub_seq_length
         )
 
-        # # change view [b, num_heads, sq, hn]
+        # change view [batch_size, num_heads, sub_seq_len, head_size]
         context_layer = context_layer.view(*output_size)
 
-        # # [b, np, sq, hn] --> [sq, b, np, hn]
+        # [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
         context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
 
-        # # [sq, b, np, hn] --> [sq, b, hp]
+        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
         new_context_layer_shape = context_layer.size()[:-2] + (
             self.hidden_size_per_attention_head * self.num_attention_heads,)
         context_layer = context_layer.view(*new_context_layer_shape)
 
-        # context_layer = context_layer.transpose(1, 0).contiguous()
-        output = self.dense(context_layer)
-        bias = self.dense.bias
+        output, bias = self.dense(context_layer)
 
         return output, bias
+
+    def __repr__(self):
+        return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \
+            f'layer_number={self.layer_number}, hidden_size={self.hidden_size}, attention_dropout={self.attention_dropout}, ' \
+            f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \
+            f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \
+            f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})'
+
+
+class _Linear(nn.Module):
+    """A simple linear layer defined as Y = XA + b, with an option to skip the bias
+    addition and return the bias separately so it can be fused with other elementwise ops.
+    Arguments:
+        input_size: first dimension of matrix A.
+        output_size: second dimension of matrix A.
+        bias: If true, add a bias term (always initialized to zero).
+        skip_bias_add: This was added to enable performance optimizations where bias
+                       can be fused with other elementwise operations. We skip
+                       adding bias but instead return it.
+ """ + + def __init__(self, + input_size, + output_size, + bias=True, + skip_bias_add=False): + super(_Linear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + + self.weight = Parameter(torch.empty(self.output_size, + self.input_size, + )) + nn.init.xavier_normal_(self.weight) + + if bias: + self.bias = Parameter(torch.empty(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + # Matrix multiply. + bias = self.bias if not self.skip_bias_add else None + output = F.linear(input_, self.weight, bias) + + if self.skip_bias_add: + return output, self.bias + else: + return output + + def __repr__(self): + return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \ + f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})' diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index b346f1a57..2ce181954 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -1,8 +1,8 @@ from .activation_checkpoint import checkpoint from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0, - is_using_ddp, is_using_pp, multi_tensor_applier, param_is_not_tensor_parallel_duplicate, - print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param_in_dp) + is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier, param_is_not_tensor_parallel_duplicate, + print_rank_0, switch_virtual_pipeline_parallel_rank, sync_model_param) from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize from .data_sampler import DataParallelSampler, get_dataloader from .gradient_accumulation import accumulate_gradient @@ -10,9 +10,9 @@ from .memory import report_memory_usage from .timer import MultiTimer, Timer __all__ = [ - 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param_in_dp', 'is_dp_rank_0', 'is_tp_rank_0', - 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'conditional_context', 'is_model_parallel_parameter', - 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', + 'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0', + 'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context', + 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank' diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index a93818b07..c5c4fbab2 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -47,16 +47,16 @@ def free_port(): continue -def sync_model_param_in_dp(model): +def sync_model_param(model, parallel_mode): '''Make sure data parameters are consistent during Data Parallel Mode :param model: A pyTorch nn.model on whose parameters you check the consistency ''' - if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1: + if gpc.is_initialized(parallel_mode) 
and gpc.get_world_size(parallel_mode) > 1: for param in model.parameters(): - ranks = gpc.get_ranks_in_group(ParallelMode.DATA) + ranks = gpc.get_ranks_in_group(parallel_mode) dist.broadcast( - param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA)) + param, src=ranks[0], group=gpc.get_group(parallel_mode)) def is_dp_rank_0(): @@ -79,6 +79,10 @@ def is_using_pp(): return gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1 +def is_using_sequence(): + return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1 + + @contextmanager def conditional_context(context_manager, enable=True): if enable: @@ -240,16 +244,20 @@ def count_zeros_fp32(parameters): num_zeros = grad.numel() - torch.count_nonzero(grad) total_num_zeros = num_zeros + total_num_zeros + total_num_zeros = torch.IntTensor([int(total_num_zeros)]).cuda() + # Sum across all model-parallel GPUs. ops = [] ops.append(dist.all_reduce(total_num_zeros, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR), async_op=True)) - ops.append(dist.all_reduce(total_num_zeros, - op=dist.ReduceOp.SUM, - group=gpc.get_group(ParallelMode.PIPELINE), - async_op=True)) + if gpc.is_initialized(ParallelMode.PIPELINE): + ops.append(dist.all_reduce(total_num_zeros, + op=dist.ReduceOp.SUM, + group=gpc.get_group(ParallelMode.PIPELINE), + async_op=True)) + for req in ops: req.wait() total_num_zeros = total_num_zeros.item() diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py index c1a711c2c..712f97d96 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/utils/memory.py @@ -40,8 +40,6 @@ def report_memory_usage(message, logger=None, report_cpu=False): :type report_cpu: bool :raises EnvironmentError: raise error if no distributed environment has been initialized ''' - if not gpc.is_initialized(ParallelMode.GLOBAL): - raise EnvironmentError("No distributed environment is initialized") gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated()) gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated()) diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 1d121d5de..1ca4f8e86 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- import time +from typing import Tuple from .cuda import synchronize @@ -8,6 +9,7 @@ class Timer: ''' A timer object which helps to log the execution times, and provides different tools to assess the times. 
''' + def __init__(self): self._started = False self._start_time = time.time() @@ -129,6 +131,6 @@ class MultiTimer: def set_status(self, mode: bool): self._on = mode - def __iter__(self): + def __iter__(self) -> Tuple[str, Timer]: for name, timer in self._timers.items(): - yield name, timer \ No newline at end of file + yield name, timer diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py index 1ee104eb2..00e008460 100644 --- a/tests/test_layers/test_sequence/test_sequence.py +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -1,48 +1,150 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest +import colossalai +import colossalai.nn as col_nn import torch +import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.initialize import launch -from colossalai.logging import get_dist_logger -from checks_seq.check_layer_seq import * +import pytest + +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode from functools import partial -from colossalai.utils import free_port CONFIG = dict( parallel=dict( - pipeline=1, - tensor=dict(mode='sequence', size=4) + tensor=dict(size=4, mode='sequence') ) ) -def check_layer(): - check_selfattention() +def check_ring_qk(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 32 + attention_head_size = 32 + sub_seq_length = seq_length // world_size + + # create master tensors + q = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda() + k = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_q = q.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous() + sub_k = k.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous() + + # set autograd attributes + q.requires_grad = True + k.requires_grad = True + q.retain_grad() + k.retain_grad() + sub_q.requires_grad = True + sub_k.requires_grad = True + sub_q.retain_grad() + sub_k.retain_grad() + + # compute master attention scores + a = torch.matmul(q, k.transpose(2, 1)) + + # compute distributed attention scores + ring_qk = colossalai.nn.layer.parallel_sequence.RingQK.apply + sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length) + + # check master and distributed attetion scores + sub_master_a = a[:, rank*sub_seq_length:(rank+1)*sub_seq_length] + assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2) + + # run master backward + a.retain_grad() + a.mean().backward() + + # run distributed backward + partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length] + torch.autograd.backward(sub_a, partial_master_a_grad) + + # check master and distributed grads + partial_master_q_grad = q.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length] + assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \ + 'attention score cannot match' -def run_check_sequence(rank, world_size, port): - # init dist - launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - logger = get_dist_logger() - logger.info('Distributed environment is initialzied.', ranks=[0]) +def check_ring_av(rank, world_size): + # params + batch_size = 4 + num_heads = 4 + seq_length = 16 + attention_head_size = 32 + sub_seq_length = 
seq_length // world_size - # check layers - check_layer() + # create master tensors + a = torch.rand(batch_size*num_heads, seq_length, seq_length).cuda() + v = torch.rand(batch_size*num_heads, seq_length, attention_head_size).cuda() + dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE)) + + # create distributed tensors + sub_a = a.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous() + sub_v = v.clone()[:, rank*sub_seq_length:(rank+1)*sub_seq_length].contiguous() + + # set autograd attributes + a.requires_grad = True + v.requires_grad = True + a.retain_grad() + v.retain_grad() + sub_a.requires_grad = True + sub_v.requires_grad = True + sub_a.retain_grad() + sub_v.retain_grad() + + # compute master attention scores + out = torch.matmul(a, v) + + # compute distributed attention scores + ring_av = colossalai.nn.layer.parallel_sequence.RingAV.apply + sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length) + + # print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}') + + # check master and distributed output + sub_master_out = out[:, rank*sub_seq_length:(rank+1)*sub_seq_length] + assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2) + + # # run master backward + out.retain_grad() + out.mean().backward() + + # # run distributed backward + partial_master_out_grad = out.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length] + torch.autograd.backward(sub_out, partial_master_out_grad) + + # # check master and distributed grads + partial_master_a_grad = a.grad[:, rank*sub_seq_length:(rank+1)*sub_seq_length] + assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \ + 'attention output cannot match' + + +def run_test(rank, world_size): + colossalai.launch( + rank=rank, + world_size=world_size, + config=CONFIG, + host='localhost', + port=29500 + ) + + # check_ring_qk(rank, world_size) + check_ring_av(rank, world_size) + + gpc.destroy() torch.cuda.empty_cache() @pytest.mark.dist def test_sequence(): world_size = 4 - run_func = partial(run_check_sequence, world_size=world_size, port=free_port()) + run_func = partial(run_test, world_size=world_size) mp.spawn(run_func, nprocs=world_size)
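
For reference, the bucketed gradient all-reduce that SequenceParallelGradientHandler performs in the patch above boils down to the following standalone pattern. This is a minimal sketch written against plain torch.distributed; the function name and the absence of the ColossalAI global context are illustrative assumptions, not part of the patch.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def allreduce_gradients_by_dtype(model: torch.nn.Module, group=None):
    # Requires torch.distributed to be initialized; `group` would correspond to
    # the SEQUENCE_DP process group created by Initializer_Sequence_DP above.
    buckets = {}
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            # one bucket per tensor type, mirroring the handler in the patch
            buckets.setdefault(param.data.type(), []).append(param)

    world_size = dist.get_world_size(group=group)
    for params in buckets.values():
        grads = [p.grad.data for p in params]
        coalesced = _flatten_dense_tensors(grads)  # flatten the bucket into one tensor
        coalesced /= world_size                    # pre-divide so the all-reduce sum yields a mean
        dist.all_reduce(coalesced, group=group)
        for grad, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            grad.copy_(synced)                     # write the averaged values back in place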
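
Similarly, the two process groups created by the updated Initializer_Sequence can be inspected after launch. The snippet below is a hypothetical usage sketch based on the test configuration above; the port number is a placeholder, and it only prints which ParallelMode each group is registered under.

from functools import partial

import torch.multiprocessing as mp

import colossalai
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))


def inspect_groups(rank, world_size):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size,
                      host='localhost', port=29501, backend='nccl')
    # ring communication group used by RingQK / RingAV during the forward pass
    seq_size = gpc.get_world_size(ParallelMode.SEQUENCE)
    # gradient all-reduce group spanning all ranks in the same pipeline stage
    dp_ranks = gpc.get_ranks_in_group(ParallelMode.SEQUENCE_DP)
    print(f'rank {rank}: SEQUENCE world size = {seq_size}, SEQUENCE_DP ranks = {dp_ranks}')
    gpc.destroy()


if __name__ == '__main__':
    world_size = 4
    mp.spawn(partial(inspect_groups, world_size=world_size), nprocs=world_size)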