mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 01:12:42 +00:00
fix format for dir-[parallel_3d] (#333)
This commit is contained in:
parent
eaac03ae1d
commit
cbb6436ff0
@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
|
|||||||
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
||||||
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||||
output_parallel_mode: ParallelMode) -> Tensor:
|
output_parallel_mode: ParallelMode) -> Tensor:
|
||||||
"""
|
r"""
|
||||||
3D parallel Layernorm
|
3D parallel Layernorm
|
||||||
|
|
||||||
:param input_: input matrix
|
:param input_: input matrix
|
||||||
@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
|
|||||||
:type weight: torch.tensor
|
:type weight: torch.tensor
|
||||||
:param bias: matrix of bias
|
:param bias: matrix of bias
|
||||||
:type bias: torch.tensor
|
:type bias: torch.tensor
|
||||||
:param normalized_shape: input shape from an expected input
|
:param normalized_shape: input shape from an expected input of size.
|
||||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||||
|
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||||
If a single integer is used, it is treated as a singleton list, and this module will
|
If a single integer is used, it is treated as a singleton list, and this module will
|
||||||
normalize over the last dimension which is expected to be of that specific size.
|
normalize over the last dimension which is expected to be of that specific size.
|
||||||
:type normalized_shape: int
|
:type normalized_shape: int
|
||||||
@ -282,7 +283,7 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
|
|||||||
:type tensor: torch.Tensor
|
:type tensor: torch.Tensor
|
||||||
:type dim: int
|
:type dim: int
|
||||||
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||||
|
|
||||||
:return output: Split tensor
|
:return output: Split tensor
|
||||||
:rtype output: torch.Tensor
|
:rtype output: torch.Tensor
|
||||||
"""
|
"""
|
||||||
@ -294,9 +295,9 @@ def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Te
|
|||||||
|
|
||||||
|
|
||||||
def split_batch_3d(input_: Tensor,
|
def split_batch_3d(input_: Tensor,
|
||||||
dim: int = 0,
|
dim: int = 0,
|
||||||
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
||||||
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
||||||
"""Splits 3D tensor in batch
|
"""Splits 3D tensor in batch
|
||||||
:param input_: Input tensor
|
:param input_: Input tensor
|
||||||
:param dim: Specified dimension in which to split
|
:param dim: Specified dimension in which to split
|
||||||
@ -333,8 +334,8 @@ class _ReduceTensor3D(torch.autograd.Function):
|
|||||||
|
|
||||||
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
||||||
"""
|
"""
|
||||||
All-reduce the input.
|
All-reduce the input
|
||||||
|
|
||||||
:param tensor: Input tensor
|
:param tensor: Input tensor
|
||||||
:param parallel_mode: Parallel mode
|
:param parallel_mode: Parallel mode
|
||||||
"""
|
"""
|
||||||
@ -359,7 +360,7 @@ class _AllGatherTensor3D(torch.autograd.Function):
|
|||||||
def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
def all_gather_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
||||||
"""
|
"""
|
||||||
All-reduce the gradient in backward pass.
|
All-reduce the gradient in backward pass.
|
||||||
|
|
||||||
:param tensor: Input tensor
|
:param tensor: Input tensor
|
||||||
:param parallel_mode: Parallel mode
|
:param parallel_mode: Parallel mode
|
||||||
"""
|
"""
|
||||||
@ -383,7 +384,7 @@ class _ReduceScatterTensor3D(torch.autograd.Function):
|
|||||||
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Reduce-scatter the input.
|
Reduce-scatter the input.
|
||||||
|
|
||||||
:param tensor: Input tensor
|
:param tensor: Input tensor
|
||||||
:param dim: Dimension to scatter
|
:param dim: Dimension to scatter
|
||||||
:param parallel_mode: Parallel mode
|
:param parallel_mode: Parallel mode
|
||||||
@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
|
|||||||
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||||
:param weight_parallel_mode: weight parallel mode
|
:param weight_parallel_mode: weight parallel mode
|
||||||
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||||
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
|
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
|
||||||
|
default to False
|
||||||
:type reduce_mean: bool, optional
|
:type reduce_mean: bool, optional
|
||||||
"""
|
"""
|
||||||
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
|
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
|
||||||
|
@ -17,7 +17,8 @@ from torch import Tensor
|
|||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
|
|
||||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||||
from ._operation import *
|
from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
|
||||||
|
from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
|
||||||
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||||
|
|
||||||
|
|
||||||
@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
|
|||||||
r"""
|
r"""
|
||||||
Layer Normalization for 3D parallelism
|
Layer Normalization for 3D parallelism
|
||||||
|
|
||||||
:param normalized_shape: input shape from an expected input
|
:param normalized_shape: input shape from an expected input of size.
|
||||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||||
|
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||||
If a single integer is used, it is treated as a singleton list, and this module will
|
If a single integer is used, it is treated as a singleton list, and this module will
|
||||||
normalize over the last dimension which is expected to be of that specific size.
|
normalize over the last dimension which is expected to be of that specific size.
|
||||||
:type normalized_shape: int
|
:type normalized_shape: int
|
||||||
@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
|
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||||
@ -405,7 +408,7 @@ class PatchEmbedding3D(ParallelLayer):
|
|||||||
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
||||||
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
||||||
if self.flatten:
|
if self.flatten:
|
||||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||||
|
|
||||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||||
output = torch.cat((cls_token, output), dim=1)
|
output = torch.cat((cls_token, output), dim=1)
|
||||||
@ -549,7 +552,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
|
|||||||
|
|
||||||
def _fill_padding_idx_with_zero(self) -> None:
|
def _fill_padding_idx_with_zero(self) -> None:
|
||||||
if self.padding_idx is not None and \
|
if self.padding_idx is not None and \
|
||||||
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
self.padding_idx >= self.vocab_start_index and self.padding_idx < self.vocab_end_index:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user