updated tp layers

This commit is contained in:
kurisusnowdeng
2022-10-26 20:54:39 +08:00
committed by アマデウス
parent cb5a587e9a
commit 0b8161fab8
13 changed files with 645 additions and 293 deletions

View File

@@ -1,4 +1,6 @@
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
try:
import fused_mix_prec_layer_norm_cuda
@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class LinearWithAsyncCommunication(torch.autograd.Function):
"""
Linear layer execution with asynchronous communication in backprop.
"""
@staticmethod
def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.parallel_mode = parallel_mode
ctx.async_grad_allreduce = async_grad_allreduce
output = torch.matmul(input_, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
total_input = input
grad_input = grad_output.matmul(weight)
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)

View File

@@ -20,12 +20,12 @@ from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
from ._operation import linear_with_async_comm
@LAYERS.register_module
@@ -96,8 +96,25 @@ class LayerNorm1D(ColossalaiModule):
dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
"""
_fast_ln_supported_sizes = [
1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536
]
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
from apex.normalization import FusedLayerNorm
fast_ln_installed = False
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
fast_ln_installed = True
except ImportError:
pass
if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
else:
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
super().__init__(norm)
def _load_from_state_dict(self, state_dict, prefix, *args):
@@ -519,11 +536,12 @@ class Linear1D_Col(ParallelLayer):
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
# output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
@@ -665,6 +683,7 @@ class Linear1D_Row(ParallelLayer):
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
if not self.skip_bias_add:

View File

@@ -9,7 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from ._utils import get_parallel_mode_from_env, push_async_grad
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
@@ -17,34 +17,27 @@ class _Linear3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
ctx.use_bias = bias is not None
input_ = all_gather(input_, input_dim, input_parallel_mode)
weight = all_gather(weight, weight_dim, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, output_dim, output_parallel_mode)
if bias is not None:
output += bias
def forward(
ctx,
input_: Tensor,
weight: Tensor,
weight_id: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_dim = input_dim
ctx.weight_dim = weight_dim
ctx.output_dim = output_dim
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight, -1, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)
return output
@staticmethod
@@ -52,73 +45,70 @@ class _Linear3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)
async_ops = list()
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
async_ops.append(op)
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
else:
bias_grad = None
input_op.wait()
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, None, None, None, None
def linear_3d(input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
def linear_3d(
input_: Tensor,
weight: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""Linear layer for 3D parallelism.
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
input_dim (int, optional): dimension of input, defaults to 0.
weight_dim (int, optional): dimension of weight, defaults to -1.
output_dim (int, optional): dimension of output, defaults to 0.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
input_dim, weight_dim, output_dim)
return _Linear3D.apply(
input_,
weight,
id(weight),
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Classifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)]
weight = broadcast(weight, src_rank, input_parallel_mode)
ctx.save_for_backward(input_, weight)
@@ -126,6 +116,7 @@ class _Classifier3D(torch.autograd.Function):
output = all_reduce(output, output_parallel_mode)
if bias is not None:
ctx.bias_id = bias_id
output += bias
ctx.src_rank = src_rank
@@ -139,14 +130,12 @@ class _Classifier3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
async_ops = list()
weight_grad = torch.matmul(
output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
else:
weight_grad = None
@@ -154,21 +143,23 @@ class _Classifier3D(torch.autograd.Function):
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
input_grad = torch.matmul(output_grad, weight)
for op in async_ops:
if op is not None:
op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel classifier.
Args:
@@ -183,16 +174,134 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
return _Classifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _VocabParallelClassifier3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
weight_id: int,
bias_id: Optional[int],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
ctx.use_bias = bias is not None
ctx.weight_id = weight_id
input_ = all_gather(input_, 0, input_parallel_mode)
weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
ctx.save_for_backward(input_, weight)
output = torch.matmul(input_, weight)
output = reduce_scatter(output, 0, output_parallel_mode)
if bias is not None:
ctx.bias_id = bias_id
output += bias
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_, weight = ctx.saved_tensors
with torch.no_grad():
output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)
weight_grad = torch.matmul(
input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
if ctx.use_bias:
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
else:
bias_grad = None
input_op.wait()
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def vocab_parallel_classifier_3d(
input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D vocab parallel classifier.
Args:
input_ (:class:`torch.tensor`): input matrix.
weight (:class:`torch.tensor`): matrix of weight.
bias (:class:`torch.tensor`): matrix of bias.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _VocabParallelClassifier3D.apply(
input_,
weight,
bias,
id(weight),
id(bias) if bias is not None else None,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
)
class _Layernorm3D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx,
input_: Tensor,
weight: Tensor,
bias: Tensor,
weight_id: int,
bias_id: int,
normalized_shape: int,
eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
ctx.weight_id = weight_id
ctx.bias_id = bias_id
mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
mu = input_ - mean
var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@@ -201,15 +310,13 @@ class _Layernorm3D(torch.autograd.Function):
ctx.save_for_backward(mu, sigma, weight)
z = mu / sigma
output = weight * z
if bias is not None:
output = output + bias
output = weight * z + bias
ctx.use_bias = bias is not None
ctx.normalized_shape = normalized_shape
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode
return output
@@ -218,17 +325,14 @@ class _Layernorm3D(torch.autograd.Function):
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
mu, sigma, weight = ctx.saved_tensors
with torch.no_grad():
weight_grad = output_grad * mu / sigma
if ctx.use_bias:
bias_grad = output_grad
weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
else:
bias_grad = None
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
if ctx.use_bias:
bias_grad, weight_grad = weight_grad[0], weight_grad[1]
bias_grad, weight_grad = output_grad, output_grad * mu / sigma
bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
dz = output_grad * weight
dvar = dz * mu * (-0.5) * sigma**(-3)
@@ -236,15 +340,22 @@ class _Layernorm3D(torch.autograd.Function):
dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
input_grad = dz / sigma + dvar * 2 * mu / \
ctx.normalized_shape + dmean / ctx.normalized_shape
input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
return input_grad, weight_grad, bias_grad, None, None, None, None, None
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
def layernorm_3d(
input_: Tensor,
weight: Tensor,
bias: Tensor,
normalized_shape: int,
eps: float,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
r"""3D parallel Layernorm.
Args:
@@ -265,8 +376,19 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
return _Layernorm3D.apply(
input_,
weight,
bias,
id(weight),
id(bias),
normalized_shape,
eps,
input_parallel_mode,
weight_parallel_mode,
output_parallel_mode,
input_x_weight_parallel_mode,
)
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@@ -315,17 +437,12 @@ def split_batch_3d(input_: Tensor,
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
dim_size = input_.size(dim)
if input_.size(dim) <= 1:
return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_world_size = gpc.get_world_size(weight_parallel_mode)
input_world_size = gpc.get_world_size(input_parallel_mode)
assert dim_size % (input_world_size*weight_world_size) == 0, \
f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
return output
@@ -464,47 +581,3 @@ def reduce_by_batch_3d(tensor: Tensor,
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
r"""broadcast weight from diagonal.
Args:
input_ (:class:`torch.tensor`): input matrix.
input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
Note:
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode) -> Tensor:
ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
output = broadcast(input_, src_rank, input_parallel_mode)
ctx.src_rank = src_rank
ctx.input_parallel_mode = input_parallel_mode
ctx.weight_parallel_mode = weight_parallel_mode
ctx.output_parallel_mode = output_parallel_mode
return output
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
else:
input_grad = None
return input_grad, None, None, None
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)

View File

@@ -1,8 +1,13 @@
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from collections import OrderedDict
from functools import partial
import torch
from torch import Tensor
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor
def get_depth_from_env() -> int:
@@ -17,30 +22,17 @@ def get_depth_from_env() -> int:
def get_parallel_mode_from_env(group):
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group)
def get_last_group(a, b):
mapping = {
ParallelMode.PARALLEL_3D_INPUT: 'A',
ParallelMode.PARALLEL_3D_WEIGHT: 'B',
ParallelMode.PARALLEL_3D_OUTPUT: 'C',
}
res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT
elif res == 'B':
return ParallelMode.PARALLEL_3D_WEIGHT
elif res == 'C':
return ParallelMode.PARALLEL_3D_OUTPUT
def swap_in_out_group():
env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
env.input_x_weight_group_3d, env.output_x_weight_group_3d = (
env.output_x_weight_group_3d,
env.input_x_weight_group_3d,
)
def dbg_check_shape(tensor: Tensor, shape: tuple):
@@ -49,3 +41,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
print(tensor.shape)
assert tensor.shape == shape, \
'{} does not match {}'.format(tensor.shape, shape)
class AsyncGradientBucket(object):
def __init__(self):
self.bucket = OrderedDict()
def __len__(self):
return len(self.bucket)
def push(self, async_op, grad_tensor, param_id):
self.bucket[param_id] = tuple((async_op, grad_tensor))
return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device)
def pop(self, param_id):
grad = None
if param_id in self.bucket:
op, grad = self.bucket.pop(param_id)
if op is not None:
op.wait()
return grad
def synchronize(self, params):
for p in params:
i = id(p)
if i in self.bucket:
op, grad = self.bucket.pop(i)
if op is not None:
op.wait()
p.grad.add_(grad)
_async_grad_bucket = AsyncGradientBucket()
def push_async_grad(op, grad, param_id):
return _async_grad_bucket.push(op, grad, param_id)
def pop_async_grad(param_id):
return _async_grad_bucket.pop(param_id)
def _async_grad_hook(grad, param_id):
grad.add_(pop_async_grad(param_id))
return grad
def register_async_grad_hook(param):
param.register_hook(partial(_async_grad_hook, param_id=id(param)))
def synchronize(params=list()):
_async_grad_bucket.synchronize(params)
torch.cuda.default_stream().synchronize()
if len(_async_grad_bucket) > 0:
raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.")

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
@@ -20,9 +20,9 @@ from torch import Tensor
from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
@LAYERS.register_module
@@ -45,7 +45,8 @@ class LayerNorm3D(ParallelLayer):
super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
@@ -58,6 +59,7 @@ class LayerNorm3D(ParallelLayer):
else:
self.bias = None
self.variance_epsilon = eps
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self) -> None:
@@ -67,8 +69,10 @@ class LayerNorm3D(ParallelLayer):
def reset_parameters(self) -> None:
init.ones_()(self.weight)
register_async_grad_hook(self.weight)
if self.bias is not None:
init.zeros_()(self.bias)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@@ -134,8 +138,17 @@ class LayerNorm3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
return layernorm_3d(
input_,
self.weight,
self.bias,
self.normalized_shape,
self.variance_epsilon,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
self.input_x_weight_parallel_mode,
)
@LAYERS.register_module
@@ -161,6 +174,7 @@ class Linear3D(ParallelLayer):
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@@ -168,8 +182,10 @@ class Linear3D(ParallelLayer):
self.out_features = out_features
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.skip_bias_add = skip_bias_add
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
self.bias_features_per_partition = divide(out_features, self.depth)
@@ -194,18 +210,23 @@ class Linear3D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
self.bias.register_hook(self._sync_grad_hook)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@@ -324,8 +345,20 @@ class Linear3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = linear_3d(
input_,
self.weight,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
if not self.skip_bias_add:
if self.bias is not None:
output = output + self.bias
return output
else:
return output, self.bias
@LAYERS.register_module
@@ -360,7 +393,7 @@ class Classifier3D(ParallelLayer):
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
@@ -386,19 +419,17 @@ class Classifier3D(ParallelLayer):
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@@ -468,8 +499,14 @@ class Classifier3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
return classifier_3d(
input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module
@@ -504,7 +541,8 @@ class VocabParallelClassifier3D(ParallelLayer):
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(num_classes, self.depth**2)
@@ -544,12 +582,14 @@ class VocabParallelClassifier3D(ParallelLayer):
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
register_async_grad_hook(self.weight)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias,
gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
self.output_x_weight_parallel_mode)
register_async_grad_hook(self.bias)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
@@ -668,8 +708,14 @@ class VocabParallelClassifier3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
return vocab_parallel_classifier_3d(
input_,
self.weight,
self.bias,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode,
)
@LAYERS.register_module
@@ -708,12 +754,16 @@ class PatchEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.patch_size = to_2tuple(patch_size)
grid_size = to_2tuple(img_size // patch_size)
num_patches = grid_size[0] * grid_size[1]
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.embed_size = embed_size
embed_size_per_partition = divide(embed_size, self.depth)
embed_size_per_partition = embed_size // self.depth
self.flatten = flatten
self.weight = nn.Parameter(
@@ -725,7 +775,7 @@ class PatchEmbedding3D(ParallelLayer):
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attributes()
@@ -737,8 +787,7 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode)
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
@@ -749,14 +798,10 @@ class PatchEmbedding3D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0]
broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode)
broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
self.bias.register_hook(self._sync_grad_hook)
@@ -850,11 +895,12 @@ class PatchEmbedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
@@ -906,7 +952,8 @@ class Embedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
@@ -924,13 +971,18 @@ class Embedding3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth)
def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
return grad
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight,
gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
@@ -981,11 +1033,10 @@ class Embedding3D(ParallelLayer):
destination.update(local_state)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
input_ = split_batch_3d(input_,
input_parallel_mode=self.input_parallel_mode,
weight_parallel_mode=self.weight_parallel_mode)
output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output
@@ -1039,7 +1090,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)