diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py
index 8974ff377..8920584a1 100644
--- a/colossalai/nn/layer/parallel_2p5d/_operation.py
+++ b/colossalai/nn/layer/parallel_2p5d/_operation.py
@@ -26,20 +26,20 @@ class _Classifier2p5D(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(
-        ctx: Any,
-        A: Tensor,
-        B: Tensor,
-        bias,
-        tesseract_dim: int,
-        out_shape: Tuple[int, ...],
-        row_rank: int,
-        col_rank: int,
-        row_parallel_mode: ParallelMode,
-        col_parallel_mode: ParallelMode,
-        data_parallel_rank: int,
-        pipeline_parallel_rank: int,
-        pipeline_parallel_size: int,
-        tensor_parallel_size: int,
+            ctx: Any,
+            A: Tensor,
+            B: Tensor,
+            bias,
+            tesseract_dim: int,
+            out_shape: Tuple[int, ...],
+            row_rank: int,
+            col_rank: int,
+            row_parallel_mode: ParallelMode,
+            col_parallel_mode: ParallelMode,
+            data_parallel_rank: int,
+            pipeline_parallel_rank: int,
+            pipeline_parallel_size: int,
+            tensor_parallel_size: int,
     ) -> Tensor:
         A_shape = A.shape
 
@@ -166,6 +166,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
     :param tensor_parallel_size: tensor parallel size
     :type tensor_parallel_size: int
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@@ -197,10 +198,14 @@ class Matmul_AB_2p5D(torch.autograd.Function):
         row_group = gpc.get_group(row_parallel_mode)
         col_group = gpc.get_group(col_parallel_mode)
 
-        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_a = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
 
         opa = [None] * 2
         opb = [None] * 2
@@ -295,6 +300,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
     :param tensor_parallel_size: tensor parallel size
     :type tensor_parallel_size: int
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@@ -323,10 +329,14 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
         row_group = gpc.get_group(row_parallel_mode)
         col_group = gpc.get_group(col_parallel_mode)
 
-        src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_b = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
 
         opb = [None] * 2
         opr = [None] * 2
@@ -429,6 +439,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
     :param tensor_parallel_size: tensor parallel size
     :type tensor_parallel_size: int
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
@@ -457,10 +468,14 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
         row_group = gpc.get_group(row_parallel_mode)
         col_group = gpc.get_group(col_parallel_mode)
 
-        src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
-        src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_a = \
+            tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = \
+            col_rank + tesseract_dim ** 2 * dep_rank + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
 
         opa = [None] * 2
         opr = [None] * 2
@@ -540,8 +555,10 @@ class _Add_Bias_2p5D(torch.autograd.Function):
             bias_temp = bias.clone()
         else:
             bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
-        src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-            pipeline_parallel_rank * tensor_parallel_size
+        src_rank = \
+            col_rank + dep_rank * tesseract_dim ** 2 + \
+            data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
         dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
 
         ctx.row_rank = row_rank
@@ -575,27 +592,37 @@ class _Add_Bias_2p5D(torch.autograd.Function):
         tensor_parallel_size = ctx.tensor_parallel_size
 
         if ctx.bias:
-            dst_rank = col_rank + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            dst_rank = \
+                col_rank + dep_rank * (tesseract_dim ** 2) + \
+                data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
             if row_rank == 0:
-                return None, output_grad, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    None, output_grad, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
             else:
                 grad_tmp = torch.zeros_like(output_grad)
-                return None, grad_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    None, grad_tmp, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
         else:
             reduce_dim = tuple(range(output_grad.ndim - 1))
             reduce = torch.sum(output_grad, dim=reduce_dim)
-            dst_rank = col_rank + dep_rank * (
-                tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            dst_rank = \
+                col_rank + dep_rank * (tesseract_dim ** 2) + \
+                data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
             if row_rank == 0:
-                return output_grad, reduce, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    output_grad, reduce, None, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None
             else:
                 reduce_tmp = torch.zeros_like(reduce)
-                return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
+                return \
+                    output_grad, reduce_tmp, None, None, None, None, None, None, \
+                    None, None, None, None, None, None, None, None, None
 
 
 def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
@@ -621,7 +648,8 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
     :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     :param col_parallel_mode: column parallel mode
     :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
+    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer,
+        which is preserved for kernel fusion
     :type skip_bias_add: bool
     :param data_parallel_rank: data parallel rank
     :type data_parallel_rank: int
@@ -652,6 +680,7 @@ class _Layernorm2p5D(torch.autograd.Function):
     :param row_parallel_mode: row parallel mode
     :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
@@ -748,6 +777,7 @@ class SplitFirst(torch.autograd.Function):
     :param col_parallel_mode: column parallel mode
     :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     """
+
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
     def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@@ -762,7 +792,7 @@ class SplitFirst(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
+        grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
         grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
         dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
                         output_grad.contiguous(),
@@ -775,10 +805,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
 
     :param input_: Input tensor
     :param dim: Specified dimension in which to split
-    
+
    :type input_: torch.Tensor
    :type dim: int, optional
-    
+
     :return output: Splitted tensor
     :rtype output: torch.Tensor
     """
@@ -801,7 +831,7 @@ class _ReduceTensor2p5D(torch.autograd.Function):
 def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the input.
-    
+
     :param input_: input tensor
     :param parallel_mode: parallel mode
     """
@@ -823,7 +853,7 @@ class _ReduceScatterTensor2p5D(torch.autograd.Function):
 def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     Reduce-scatter the input.
-    
+
     :param input_: input tensor
     :param parallel_mode: parallel mode
     """
@@ -868,4 +898,4 @@ def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor:
     :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
     :type reduce_mean: bool, optional
     """
-    return _RreduceByBatch2p5D.apply(input_, reduce_mean)
\ No newline at end of file
+    return _RreduceByBatch2p5D.apply(input_, reduce_mean)
diff --git a/colossalai/nn/layer/parallel_2p5d/_utils.py b/colossalai/nn/layer/parallel_2p5d/_utils.py
index bcab619ca..1478b25de 100644
--- a/colossalai/nn/layer/parallel_2p5d/_utils.py
+++ b/colossalai/nn/layer/parallel_2p5d/_utils.py
@@ -21,4 +21,5 @@ def assert_tesseract_initialization():
         gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \
         gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \
         gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \
-        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ must be initialized by the process group initializer'
+        'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
+        'must be initialized by the process group initializer'
diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py
index a803f331d..48793ab4e 100644
--- a/colossalai/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/nn/layer/parallel_2p5d/layers.py
@@ -134,8 +134,9 @@ class LayerNorm2p5D(ParallelLayer):
     r"""
     Layer Normalization for 2.5D parallelism
 
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
         If a single integer is used, it is treated as a singleton list, and this module will
         normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
@@ -431,7 +432,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
 
     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
-                self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
+                self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
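For reference, the long expressions that these hunks rewrap all compute the same thing: the global rank of the process holding a given partition, built from the tesseract (2.5D) coordinates plus the data- and pipeline-parallel offsets. Below is a minimal, self-contained sketch of that index arithmetic mirroring the src_rank / src_b form; the helper name tesseract_src_rank and the example numbers are illustrative assumptions only, not part of the patch or of the ColossalAI API.

def tesseract_src_rank(col_rank: int, dep_rank: int, tesseract_dim: int,
                       data_parallel_rank: int, pipeline_parallel_rank: int,
                       pipeline_parallel_size: int, tensor_parallel_size: int) -> int:
    # Position inside one tensor-parallel (tesseract) group: column index within the
    # dep_rank-th tesseract_dim x tesseract_dim grid.
    rank_in_tensor_group = col_rank + dep_rank * tesseract_dim ** 2
    # Offset of that tensor-parallel group among all data-/pipeline-parallel replicas;
    # equal to dp_rank * pp_size * tp_size + pp_rank * tp_size in the patch's expanded form.
    group_offset = (data_parallel_rank * pipeline_parallel_size + pipeline_parallel_rank) * tensor_parallel_size
    return rank_in_tensor_group + group_offset


# Example: tesseract_dim=2 with depth 2 gives a tensor-parallel size of 2*2*2 = 8;
# with no data or pipeline parallelism, partition (col_rank=1, dep_rank=1) lives on global rank 5.
assert tesseract_src_rank(col_rank=1, dep_rank=1, tesseract_dim=2,
                          data_parallel_rank=0, pipeline_parallel_rank=0,
                          pipeline_parallel_size=1, tensor_parallel_size=8) == 5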