fix format for dir-[parallel_3d] (#333)

DouJS 2022-03-09 10:31:43 +08:00 committed by Frank Lee
parent eaac03ae1d
commit cbb6436ff0
2 changed files with 22 additions and 17 deletions

View File

@@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
 def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
                  input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                  output_parallel_mode: ParallelMode) -> Tensor:
-    """
+    r"""
     3D parallel Layernorm
     :param input_: input maxtrix
@@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
     :type weight: torch.tensor
     :param bias: matrix of bias
     :type bias: torch.tensor
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
     If a single integer is used, it is treated as a singleton list, and this module will
     normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
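
The `normalized_shape` wording above mirrors `torch.nn.LayerNorm`. As a minimal, non-parallel sketch of what that size expression means, here is the same idea with plain `torch.nn.functional.layer_norm` rather than `layernorm_3d` (shapes below are made up):

    import torch
    import torch.nn.functional as F

    # normalized_shape names the trailing dimension(s) of the input; a single
    # int such as 256 is treated as the one-element list [256].
    x = torch.randn(8, 196, 256)              # [*, normalized_shape] = [batch, tokens, 256]
    w, b = torch.ones(256), torch.zeros(256)
    y = F.layer_norm(x, (256,), weight=w, bias=b, eps=1e-12)
    print(y.shape)                            # torch.Size([8, 196, 256])
    print(y.mean(-1).abs().max() < 1e-5)      # mean over the last dim is ~0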
@@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
     :type tensor: torch.Tensor
     :type dim: int
     :type parallel_mode: colossalai.context.parallel_mode.ParallelMode
     :return output: Splitted tensor
     :rtype output: torch.Tensor
     """
@@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
 def split_batch_3d(input_: Tensor,
                    dim: int = 0,
                    input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
                    weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
     """Splits 3D tensor in batch
     :param input_: Input tensor
     :param dim: Specified dimension in which to split
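
`split_batch_3d` splits along the batch dimension for both the input and the weight parallel modes. A rough CPU-only sketch of that idea, assuming the two groups simply split one after the other (the splitting order and group sizes here are illustrative, not taken from the implementation):

    import torch

    def split_batch_sketch(x, input_depth, input_rank, weight_depth, weight_rank, dim=0):
        # Split across the weight-parallel group, then across the input-parallel group.
        x = torch.chunk(x, weight_depth, dim=dim)[weight_rank]
        x = torch.chunk(x, input_depth, dim=dim)[input_rank]
        return x.contiguous()

    batch = torch.randn(8, 16)
    local = split_batch_sketch(batch, input_depth=2, input_rank=0, weight_depth=2, weight_rank=1)
    print(local.shape)   # torch.Size([2, 16]): 8 / (2 * 2) samples per rank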
@@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
 def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
     """
-    All-reduce the input.
+    All-reduce the input
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
@@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
 def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     All-reduce the gradient in backward pass.
     :param tensor: Input tensor
     :param parallel_mode: Parallel mode
     """
@@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
 def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
     """
     Reduce-scatter the input.
     :param tensor: Input tensor
     :param dim: Dimension to scatter
     :param parallel_mode: Parallel mode
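
And the reduce-scatter counterpart in miniature: sum over ranks first, then each rank keeps only its own slice along `dim` (a list of tensors stands in for the ranks again):

    import torch

    world_size, dim = 2, -1
    per_rank = [torch.ones(2, 4) * (r + 1) for r in range(world_size)]
    summed = torch.stack(per_rank).sum(dim=0)              # element-wise sum: all 3.0
    slices = torch.chunk(summed, world_size, dim=dim)      # rank r keeps slices[r]
    print([tuple(s.shape) for s in slices])                # [(2, 2), (2, 2)]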
@@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
     :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
     :param weight_parallel_mode: weight parallel mode
     :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
+    :param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
+        default to False
     :type reduce_mean: int, optional
     """
     return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
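
A worked example of the `reduce_mean` behaviour documented above, with made-up group sizes: dividing the all-reduced sum by (input parallel size * weight parallel size) turns a sum of per-rank values into their mean.

    import torch

    input_parallel_size, weight_parallel_size = 2, 2     # hypothetical 2 x 2 groups
    per_rank_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])   # one value per rank
    total = per_rank_loss.sum()                           # the plain all-reduced sum
    mean = total / (input_parallel_size * weight_parallel_size)
    print(total.item(), mean.item())                      # 10.0 2.5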

View File

@@ -17,7 +17,8 @@ from torch import Tensor
 from torch.nn import Parameter
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import *
+from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
+from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
 from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
@@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
     r"""
     Layer Normalization for 3D parallelism
-    :param normalized_shape: input shape from an expected input
-        of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
+    :param normalized_shape: input shape from an expected input of size.
+        :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
+        \times \ldots \times \text{normalized_shape}[-1]]`
     If a single integer is used, it is treated as a singleton list, and this module will
     normalize over the last dimension which is expected to be of that specific size.
     :type normalized_shape: int
@@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
     """
     def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
         super().__init__()
         self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
         self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -405,7 +408,7 @@ class PatchEmbedding3D(ParallelLayer):
         input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
             output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
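
The forward steps in the hunk above (patchifying conv, BCHW -> BNC flatten, class-token concat) can be reproduced with plain torch; the sizes below are made up and no 3D splitting is involved:

    import torch
    import torch.nn.functional as F

    B, C_in, H, W, embed_dim, patch = 2, 3, 32, 32, 8, 16
    x = torch.randn(B, C_in, H, W)
    weight = torch.randn(embed_dim, C_in, patch, patch)
    bias = torch.zeros(embed_dim)

    out = F.conv2d(x, weight, bias, stride=patch)      # (B, embed_dim, 2, 2)
    out = out.flatten(2).transpose(1, 2)               # BCHW -> BNC: (B, 4, embed_dim)
    cls_token = torch.zeros(1, 1, embed_dim).expand(out.shape[0], -1, -1)
    out = torch.cat((cls_token, out), dim=1)           # class token prepended: (B, 5, embed_dim)
    print(out.shape)                                   # torch.Size([2, 5, 8])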
@@ -549,7 +552,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
     def _fill_padding_idx_with_zero(self) -> None:
         if self.padding_idx is not None and \
                 self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
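
A standalone sketch of the padding-idx logic in the last hunk, with hypothetical shard bounds: the padding row is zeroed only on the rank whose local vocab slice [vocab_start_index, vocab_end_index) contains padding_idx.

    import torch

    padding_idx, vocab_start_index, vocab_end_index = 5, 4, 8      # made-up shard bounds
    weight = torch.randn(vocab_end_index - vocab_start_index, 16)  # local embedding shard

    if padding_idx is not None and \
            vocab_start_index <= padding_idx < vocab_end_index:
        with torch.no_grad():
            weight[padding_idx - vocab_start_index].fill_(0)

    print(weight[padding_idx - vocab_start_index].abs().sum().item())   # 0.0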