mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 01:12:42 +00:00
fix format for dir-[parallel_3d] (#333)
This commit is contained in:
parent
eaac03ae1d
commit
cbb6436ff0
@ -244,7 +244,7 @@ class _Layernorm3D(torch.autograd.Function):
|
||||
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
|
||||
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
|
||||
output_parallel_mode: ParallelMode) -> Tensor:
|
||||
"""
|
||||
r"""
|
||||
3D parallel Layernorm
|
||||
|
||||
:param input_: input maxtrix
|
||||
@ -253,8 +253,9 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape:
|
||||
:type weight: torch.tensor
|
||||
:param bias: matrix of bias
|
||||
:type bias: torch.tensor
|
||||
:param normalized_shape: input shape from an expected input
|
||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
:param normalized_shape: input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
:type normalized_shape: int
|
||||
@ -333,7 +334,7 @@ class _ReduceTensor3D(torch.autograd.Function):
|
||||
|
||||
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
|
||||
"""
|
||||
All-reduce the input.
|
||||
All-reduce the input
|
||||
|
||||
:param tensor: Input tensor
|
||||
:param parallel_mode: Parallel mode
|
||||
@ -431,7 +432,8 @@ def reduce_by_batch_3d(tensor: Tensor,
|
||||
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||
:param weight_parallel_mode: weight parallel mode
|
||||
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
|
||||
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
|
||||
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size),
|
||||
default to False
|
||||
:type reduce_mean: int, optional
|
||||
"""
|
||||
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
|
||||
|
@ -17,7 +17,8 @@ from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import *
|
||||
from ._operation import layernorm_3d, linear_3d, classifier_3d, split_tensor_3d
|
||||
from ._operation import all_gather_tensor_3d, reduce_scatter_tensor_3d, broadcast_weight_3d_from_diagonal
|
||||
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||
|
||||
|
||||
@ -26,8 +27,9 @@ class LayerNorm3D(ParallelLayer):
|
||||
r"""
|
||||
Layer Normalization for 3D parallelism
|
||||
|
||||
:param normalized_shape: input shape from an expected input
|
||||
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
|
||||
:param normalized_shape: input shape from an expected input of size.
|
||||
:math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
|
||||
\times \ldots \times \text{normalized_shape}[-1]]`
|
||||
If a single integer is used, it is treated as a singleton list, and this module will
|
||||
normalize over the last dimension which is expected to be of that specific size.
|
||||
:type normalized_shape: int
|
||||
@ -38,6 +40,7 @@ class LayerNorm3D(ParallelLayer):
|
||||
"""
|
||||
|
||||
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
|
||||
|
||||
super().__init__()
|
||||
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
|
Loading…
Reference in New Issue
Block a user