mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-24 19:17:30 +00:00)

updated tp layers
@@ -1,4 +1,6 @@
import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc

try:
    import fused_mix_prec_layer_norm_cuda
@@ -43,3 +45,52 @@ class FusedLayerNormAffineFunction1D(torch.autograd.Function):
                weight_, bias_, ctx.eps)

        return grad_input, grad_weight, grad_bias, None, None


class LinearWithAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication in backprop.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, parallel_mode, async_grad_allreduce):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None
        ctx.parallel_mode = parallel_mode
        ctx.async_grad_allreduce = async_grad_allreduce

        output = torch.matmul(input_, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        total_input = input
        grad_input = grad_output.matmul(weight)

        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # all-reduce scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1

        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_async_comm(input_, weight, bias, parallel_mode, async_grad_allreduce):
    return LinearWithAsyncCommunication.apply(input_, weight, bias, parallel_mode, async_grad_allreduce)
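The point of the new LinearWithAsyncCommunication is to overlap the all-reduce of grad_input with the matmul that produces grad_weight: the collective is launched with async_op=True, the weight and bias gradients are computed while it is in flight, and handle.wait() is only called at the end. Below is a minimal, ColossalAI-independent sketch of the same pattern; the process-group setup, tensor shapes and the backward_with_overlap helper are illustrative assumptions, not part of this commit.

# Illustrative sketch, not from the commit: overlap an all-reduce with local matmuls.
# Launch with e.g. `torchrun --nproc_per_node=2 overlap_demo.py` (file name is arbitrary).
import torch
import torch.distributed as dist


def backward_with_overlap(grad_output, inp, weight):
    # grad w.r.t. the input needs an all-reduce across the tensor-parallel ranks
    grad_input = grad_output.matmul(weight)
    handle = dist.all_reduce(grad_input, async_op=True)    # collective starts now

    # while the all-reduce is in flight, compute the local weight/bias grads
    go_2d = grad_output.reshape(-1, grad_output.shape[-1])
    in_2d = inp.reshape(-1, inp.shape[-1])
    grad_weight = go_2d.t().matmul(in_2d)
    grad_bias = go_2d.sum(dim=0)

    handle.wait()                                          # grad_input is complete here
    return grad_input, grad_weight, grad_bias


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")                # "nccl" on GPUs
    torch.manual_seed(dist.get_rank())
    inp = torch.randn(2, 3, 16)                            # (batch, seq, in_features)
    weight = torch.randn(32, 16)                           # (out_features, in_features)
    grad_output = torch.randn(2, 3, 32)
    gi, gw, gb = backward_with_overlap(grad_output, inp, weight)
    print(dist.get_rank(), gi.shape, gw.shape, gb.shape)
    dist.destroy_process_group()
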
@@ -20,12 +20,12 @@ from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch.nn.parameter import Parameter
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm

from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
                     split_forward_gather_backward)
from ._operation import linear_with_async_comm


@LAYERS.register_module
@@ -96,8 +96,25 @@ class LayerNorm1D(ColossalaiModule):
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
    """

    _fast_ln_supported_sizes = [
        1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
        24576, 25600, 30720, 32768, 40960, 49152, 65536
    ]

    def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
        norm = VanillaLayerNorm(normalized_shape, eps=eps, bias=bias, dtype=dtype)
        from apex.normalization import FusedLayerNorm

        fast_ln_installed = False
        try:
            from apex.contrib.layer_norm.layer_norm import FastLayerNorm
            fast_ln_installed = True
        except ImportError:
            pass

        if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
            norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
        else:
            norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
        super().__init__(norm)

    def _load_from_state_dict(self, state_dict, prefix, *args):
@@ -519,11 +536,12 @@ class Linear1D_Col(ParallelLayer):
            'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
                input_.shape, self.weight.shape, self.weight.shape[-1])
        # Set up backprop all-reduce.
        input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        # input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
        input_parallel = input_
        # Matrix multiply.

        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        # output_parallel = F.linear(input_parallel, self.weight, bias)
        output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, ParallelMode.PARALLEL_1D, True)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
@@ -665,6 +683,7 @@ class Linear1D_Row(ParallelLayer):
            input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)

        output_parallel = F.linear(input_, self.weight)
        # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

        if not self.skip_bias_add:
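Linear1D_Col keeps the weight sharded along the output dimension, so each rank produces a slice of the output that gather_forward_split_backward concatenates along the last dimension, while Linear1D_Row shards the input dimension and sums the partial products with reduce_input. The following single-process sketch of that algebra fakes the shards by slicing one full weight; the shapes and names are illustrative only, not from the commit.

# Illustrative sketch, not from the commit: column- vs row-parallel linear algebra,
# with the "ranks" faked by slicing one full weight on a single process.
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 16)                       # (batch, in_features)
weight = torch.randn(32, 16)                 # (out_features, in_features), F.linear layout
full = torch.nn.functional.linear(x, weight)

# Column parallel (Linear1D_Col): shard out_features, concatenate the partial outputs.
col_shards = weight.chunk(world_size, dim=0)
col_out = torch.cat([torch.nn.functional.linear(x, w) for w in col_shards], dim=-1)

# Row parallel (Linear1D_Row): shard in_features and the matching input slice, then sum.
row_shards = weight.chunk(world_size, dim=1)
x_shards = x.chunk(world_size, dim=-1)
row_out = sum(torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, row_shards))

assert torch.allclose(col_out, full, atol=1e-5)
assert torch.allclose(row_out, full, atol=1e-5)
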
@@ -9,7 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from ._utils import get_parallel_mode_from_env, push_async_grad
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D


@@ -17,34 +17,27 @@ class _Linear3D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx,
                input_: Tensor,
                weight: Tensor,
                bias: Optional[Tensor],
                input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode,
                input_dim: int = 0,
                weight_dim: int = -1,
                output_dim: int = 0) -> Tensor:
        ctx.use_bias = bias is not None

        input_ = all_gather(input_, input_dim, input_parallel_mode)
        weight = all_gather(weight, weight_dim, weight_parallel_mode)
        ctx.save_for_backward(input_, weight)

        output = torch.matmul(input_, weight)
        output = reduce_scatter(output, output_dim, output_parallel_mode)

        if bias is not None:
            output += bias

    def forward(
        ctx,
        input_: Tensor,
        weight: Tensor,
        weight_id: int,
        input_parallel_mode: ParallelMode,
        weight_parallel_mode: ParallelMode,
        output_parallel_mode: ParallelMode,
    ) -> Tensor:
        ctx.weight_id = weight_id
        ctx.input_parallel_mode = input_parallel_mode
        ctx.weight_parallel_mode = weight_parallel_mode
        ctx.output_parallel_mode = output_parallel_mode
        ctx.input_dim = input_dim
        ctx.weight_dim = weight_dim
        ctx.output_dim = output_dim

        input_ = all_gather(input_, 0, input_parallel_mode)
        weight = all_gather(weight, -1, weight_parallel_mode)
        ctx.save_for_backward(input_, weight)

        output = torch.matmul(input_, weight)
        output = reduce_scatter(output, 0, output_parallel_mode)

        return output

    @staticmethod
@@ -52,73 +45,70 @@ class _Linear3D(torch.autograd.Function):
    def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
        input_, weight = ctx.saved_tensors
        with torch.no_grad():
            output_grad = all_gather(output_grad, ctx.output_dim, ctx.output_parallel_mode)

            async_ops = list()
            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

            input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
            input_grad, op = reduce_scatter(input_grad, ctx.input_dim, ctx.input_parallel_mode, async_op=True)
            async_ops.append(op)
            input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

            weight_grad = torch.matmul(
                input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
            weight_grad, op = reduce_scatter(weight_grad, ctx.weight_dim, ctx.weight_parallel_mode, async_op=True)
            async_ops.append(op)
            weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

            if ctx.use_bias:
                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
                async_ops.append(op)
            else:
                bias_grad = None
            input_op.wait()

        for op in async_ops:
            if op is not None:
                op.wait()

        return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
        return input_grad, weight_grad, None, None, None, None


def linear_3d(input_: Tensor,
              weight: Tensor,
              bias: Optional[Tensor],
              input_parallel_mode: ParallelMode,
              weight_parallel_mode: ParallelMode,
              output_parallel_mode: ParallelMode,
              input_dim: int = 0,
              weight_dim: int = -1,
              output_dim: int = 0) -> Tensor:
def linear_3d(
    input_: Tensor,
    weight: Tensor,
    input_parallel_mode: ParallelMode,
    weight_parallel_mode: ParallelMode,
    output_parallel_mode: ParallelMode,
) -> Tensor:
    r"""Linear layer for 3D parallelism.

    Args:
        input_ (:class:`torch.tensor`): input matrix.
        weight (:class:`torch.tensor`): matrix of weight.
        bias (:class:`torch.tensor`): matrix of bias.
        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
        output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
        input_dim (int, optional): dimension of input, defaults to 0.
        weight_dim (int, optional): dimension of weight, defaults to -1.
        output_dim (int, optional): dimension of output, defaults to 0.

    Note:
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
                           input_dim, weight_dim, output_dim)
    return _Linear3D.apply(
        input_,
        weight,
        id(weight),
        input_parallel_mode,
        weight_parallel_mode,
        output_parallel_mode,
    )

class _Classifier3D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
                weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
    def forward(
        ctx,
        input_: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        weight_id: int,
        bias_id: Optional[int],
        input_parallel_mode: ParallelMode,
        weight_parallel_mode: ParallelMode,
        output_parallel_mode: ParallelMode,
    ) -> Tensor:
        ctx.use_bias = bias is not None
        ctx.weight_id = weight_id

        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        src_rank = gpc.get_ranks_in_group(input_parallel_mode)[gpc.get_local_rank(output_parallel_mode)]
        weight = broadcast(weight, src_rank, input_parallel_mode)
        ctx.save_for_backward(input_, weight)

@@ -126,6 +116,7 @@ class _Classifier3D(torch.autograd.Function):
        output = all_reduce(output, output_parallel_mode)

        if bias is not None:
            ctx.bias_id = bias_id
            output += bias

        ctx.src_rank = src_rank
@@ -139,14 +130,12 @@ class _Classifier3D(torch.autograd.Function):
    def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
        input_, weight = ctx.saved_tensors
        with torch.no_grad():
            async_ops = list()

            weight_grad = torch.matmul(
                output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
            weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
            if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
                weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
                async_ops.append(op)
                weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
            else:
                weight_grad = None

@@ -154,21 +143,23 @@ class _Classifier3D(torch.autograd.Function):
                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
                bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
                async_ops.append(op)
                bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
            else:
                bias_grad = None

            input_grad = torch.matmul(output_grad, weight)

        for op in async_ops:
            if op is not None:
                op.wait()

        return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
        return input_grad, weight_grad, bias_grad, None, None, None, None, None


def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
                  weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
def classifier_3d(
    input_: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    input_parallel_mode: ParallelMode,
    weight_parallel_mode: ParallelMode,
    output_parallel_mode: ParallelMode,
) -> Tensor:
    r"""3D parallel classifier.

    Args:
@@ -183,16 +174,134 @@ def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
    return _Classifier3D.apply(
        input_,
        weight,
        bias,
        id(weight),
        id(bias) if bias is not None else None,
        input_parallel_mode,
        weight_parallel_mode,
        output_parallel_mode,
    )


class _VocabParallelClassifier3D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
        ctx,
        input_: Tensor,
        weight: Tensor,
        bias: Optional[Tensor],
        weight_id: int,
        bias_id: Optional[int],
        input_parallel_mode: ParallelMode,
        weight_parallel_mode: ParallelMode,
        output_parallel_mode: ParallelMode,
    ) -> Tensor:
        ctx.use_bias = bias is not None
        ctx.weight_id = weight_id

        input_ = all_gather(input_, 0, input_parallel_mode)
        weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
        ctx.save_for_backward(input_, weight)

        output = torch.matmul(input_, weight)
        output = reduce_scatter(output, 0, output_parallel_mode)

        if bias is not None:
            ctx.bias_id = bias_id
            output += bias

        ctx.input_parallel_mode = input_parallel_mode
        ctx.weight_parallel_mode = weight_parallel_mode
        ctx.output_parallel_mode = output_parallel_mode
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
        input_, weight = ctx.saved_tensors
        with torch.no_grad():
            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

            input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
            input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

            weight_grad = torch.matmul(
                input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
            weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

            if ctx.use_bias:
                bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
                bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
                bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
            else:
                bias_grad = None

            input_op.wait()

        return input_grad, weight_grad, bias_grad, None, None, None, None, None


def vocab_parallel_classifier_3d(
    input_: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    input_parallel_mode: ParallelMode,
    weight_parallel_mode: ParallelMode,
    output_parallel_mode: ParallelMode,
) -> Tensor:
    r"""3D vocab parallel classifier.

    Args:
        input_ (:class:`torch.tensor`): input matrix.
        weight (:class:`torch.tensor`): matrix of weight.
        bias (:class:`torch.tensor`): matrix of bias.
        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
        output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.

    Note:
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    return _VocabParallelClassifier3D.apply(
        input_,
        weight,
        bias,
        id(weight),
        id(bias) if bias is not None else None,
        input_parallel_mode,
        weight_parallel_mode,
        output_parallel_mode,
    )

class _Layernorm3D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
    def forward(
        ctx,
        input_: Tensor,
        weight: Tensor,
        bias: Tensor,
        weight_id: int,
        bias_id: int,
        normalized_shape: int,
        eps: float,
        input_parallel_mode: ParallelMode,
        weight_parallel_mode: ParallelMode,
        output_parallel_mode: ParallelMode,
        input_x_weight_parallel_mode: ParallelMode,
    ) -> Tensor:
        ctx.weight_id = weight_id
        ctx.bias_id = bias_id

        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
        mu = input_ - mean
        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
@@ -201,15 +310,13 @@ class _Layernorm3D(torch.autograd.Function):
        ctx.save_for_backward(mu, sigma, weight)

        z = mu / sigma
        output = weight * z
        if bias is not None:
            output = output + bias
        output = weight * z + bias

        ctx.use_bias = bias is not None
        ctx.normalized_shape = normalized_shape
        ctx.input_parallel_mode = input_parallel_mode
        ctx.weight_parallel_mode = weight_parallel_mode
        ctx.output_parallel_mode = output_parallel_mode
        ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode

        return output

@@ -218,17 +325,14 @@ class _Layernorm3D(torch.autograd.Function):
    def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
        mu, sigma, weight = ctx.saved_tensors
        with torch.no_grad():
            weight_grad = output_grad * mu / sigma
            if ctx.use_bias:
                bias_grad = output_grad
                weight_grad = torch.stack([bias_grad, weight_grad]).contiguous()
            else:
                bias_grad = None
            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[1:-1]))
            weight_grad = all_reduce(weight_grad, ctx.weight_parallel_mode)
            weight_grad = all_reduce(weight_grad, ctx.input_parallel_mode)
            if ctx.use_bias:
                bias_grad, weight_grad = weight_grad[0], weight_grad[1]

            bias_grad, weight_grad = output_grad, output_grad * mu / sigma
            bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
            bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
            bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
            weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
            weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
            weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

            dz = output_grad * weight
            dvar = dz * mu * (-0.5) * sigma**(-3)
@@ -236,15 +340,22 @@ class _Layernorm3D(torch.autograd.Function):
            dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
            dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)

            input_grad = dz / sigma + dvar * 2 * mu / \
                ctx.normalized_shape + dmean / ctx.normalized_shape
            input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape

        return input_grad, weight_grad, bias_grad, None, None, None, None, None
        return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None


def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normalized_shape: int, eps: float,
                 input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                 output_parallel_mode: ParallelMode) -> Tensor:
def layernorm_3d(
    input_: Tensor,
    weight: Tensor,
    bias: Tensor,
    normalized_shape: int,
    eps: float,
    input_parallel_mode: ParallelMode,
    weight_parallel_mode: ParallelMode,
    output_parallel_mode: ParallelMode,
    input_x_weight_parallel_mode: ParallelMode,
) -> Tensor:
    r"""3D parallel Layernorm.

    Args:
@@ -265,8 +376,19 @@ def layernorm_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], normali
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
                              output_parallel_mode)
    return _Layernorm3D.apply(
        input_,
        weight,
        bias,
        id(weight),
        id(bias),
        normalized_shape,
        eps,
        input_parallel_mode,
        weight_parallel_mode,
        output_parallel_mode,
        input_x_weight_parallel_mode,
    )

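Apart from the collectives, _Layernorm3D.backward is the standard layer-norm gradient written out by hand: dz, dvar and dmean rebuild the chain rule through the normalization, while the weight and bias gradients are reductions of output_grad * mu / sigma and output_grad over the non-feature dimensions. The following single-device sketch checks the same formulas against autograd with the all-reduces dropped; the helper below is illustrative, not from the commit.

# Illustrative single-device check, not from the commit: the manual layer-norm
# backward used by _Layernorm3D, with the parallel reduces removed.
import torch


def manual_layernorm_backward(output_grad, x, weight, eps=1e-5):
    n = x.shape[-1]
    mu = x - x.mean(dim=-1, keepdim=True)
    var = (mu ** 2).mean(dim=-1, keepdim=True)
    sigma = torch.sqrt(var + eps)

    # parameter grads: reduce over every non-feature dimension
    bias_grad = output_grad.sum(dim=tuple(range(output_grad.dim() - 1)))
    weight_grad = (output_grad * mu / sigma).sum(dim=tuple(range(output_grad.dim() - 1)))

    # input grad, following the dz / dvar / dmean decomposition in the diff
    dz = output_grad * weight
    dvar = (dz * mu * (-0.5) * sigma ** (-3)).sum(dim=-1, keepdim=True)
    dmean = (dz * (-1 / sigma) + dvar * (-2) * mu / n).sum(dim=-1, keepdim=True)
    input_grad = dz / sigma + dvar * 2 * mu / n + dmean / n
    return input_grad, weight_grad, bias_grad


x = torch.randn(2, 3, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(8, dtype=torch.double, requires_grad=True)
b = torch.randn(8, dtype=torch.double, requires_grad=True)
out = torch.nn.functional.layer_norm(x, (8,), w, b, eps=1e-5)
g = torch.randn_like(out)
out.backward(g)

gi, gw, gb = manual_layernorm_backward(g, x.detach(), w.detach(), eps=1e-5)
assert torch.allclose(gi, x.grad, atol=1e-8)
assert torch.allclose(gw, w.grad, atol=1e-8)
assert torch.allclose(gb, b.grad, atol=1e-8)
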
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
@@ -315,17 +437,12 @@ def split_batch_3d(input_: Tensor,
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """
    dim_size = input_.size(dim)
    if input_.size(dim) <= 1:
        return input_
    weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
    input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
    weight_world_size = gpc.get_world_size(weight_parallel_mode)
    input_world_size = gpc.get_world_size(input_parallel_mode)

    assert dim_size % (input_world_size*weight_world_size) == 0, \
        f'The batch size ({dim_size}) is not a multiple of square of 3D depth ({input_world_size*weight_world_size}).'

    if input_.size(dim) <= 1:
        return input_
    output = torch.chunk(input_, weight_world_size, dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
    output = torch.chunk(output, input_world_size, dim=dim)[gpc.get_local_rank(input_parallel_mode)].contiguous()
    return output
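split_batch_3d now chunks the batch dimension twice, first across the weight group and then across the input group, so each rank ends up with batch_size / (input_world_size * weight_world_size) samples. A single-process sketch with the local ranks hard-coded (the values are illustrative, not from the commit):

# Illustrative sketch, not from the commit: the double chunking in split_batch_3d,
# with the local ranks hard-coded instead of coming from gpc.
import torch

weight_world_size, input_world_size = 2, 2          # 3D depth = 2
weight_rank, input_rank = 1, 0                      # this process's coordinates

batch = torch.arange(16).reshape(8, 2)              # batch size 8, divisible by 2 * 2
shard = torch.chunk(batch, weight_world_size, dim=0)[weight_rank].contiguous()
shard = torch.chunk(shard, input_world_size, dim=0)[input_rank].contiguous()
print(shard.shape)                                  # torch.Size([2, 2]): 8 / (2 * 2) samples
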
@@ -464,47 +581,3 @@ def reduce_by_batch_3d(tensor: Tensor,
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """
    return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)


class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
    r"""broadcast weight from diagonal.

    Args:
        input_ (:class:`torch.tensor`): input matrix.
        input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
        weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
        output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.

    Note:
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
                output_parallel_mode: ParallelMode) -> Tensor:
        ranks_in_group = gpc.get_ranks_in_group(input_parallel_mode)
        src_rank = ranks_in_group[gpc.get_local_rank(output_parallel_mode)]
        output = broadcast(input_, src_rank, input_parallel_mode)
        ctx.src_rank = src_rank
        ctx.input_parallel_mode = input_parallel_mode
        ctx.weight_parallel_mode = weight_parallel_mode
        ctx.output_parallel_mode = output_parallel_mode
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
        input_grad = reduce(output_grad, ctx.src_rank, ctx.input_parallel_mode)
        if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
            input_grad = all_reduce(input_grad, ctx.weight_parallel_mode)
        else:
            input_grad = None
        return input_grad, None, None, None


def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
                                      weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
    return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
                                                 output_parallel_mode)
@@ -1,8 +1,13 @@
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from collections import OrderedDict
from functools import partial

import torch
from torch import Tensor

from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor


def get_depth_from_env() -> int:
@@ -17,30 +22,17 @@ def get_depth_from_env():


def get_parallel_mode_from_env(group):
    assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
    assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_X_WEIGHT_3D], \
        f'{group} is not valid for 3D tensor parallelism.'
    return getattr(env, group)


def get_last_group(a, b):
    mapping = {
        ParallelMode.PARALLEL_3D_INPUT: 'A',
        ParallelMode.PARALLEL_3D_WEIGHT: 'B',
        ParallelMode.PARALLEL_3D_OUTPUT: 'C',
    }

    res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))

    if res == 'A':
        return ParallelMode.PARALLEL_3D_INPUT
    elif res == 'B':
        return ParallelMode.PARALLEL_3D_WEIGHT
    elif res == 'C':
        return ParallelMode.PARALLEL_3D_OUTPUT


def swap_in_out_group():
    env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
    env.input_x_weight_group_3d, env.output_x_weight_group_3d = (
        env.output_x_weight_group_3d,
        env.input_x_weight_group_3d,
    )


def dbg_check_shape(tensor: Tensor, shape: tuple):
@@ -49,3 +41,60 @@ def dbg_check_shape(tensor: Tensor, shape: tuple):
        print(tensor.shape)
    assert tensor.shape == shape, \
        '{} does not match {}'.format(tensor.shape, shape)


class AsyncGradientBucket(object):

    def __init__(self):
        self.bucket = OrderedDict()

    def __len__(self):
        return len(self.bucket)

    def push(self, async_op, grad_tensor, param_id):
        self.bucket[param_id] = tuple((async_op, grad_tensor))
        return torch.zeros_like(grad_tensor, dtype=grad_tensor.dtype, device=grad_tensor.device)

    def pop(self, param_id):
        grad = None
        if param_id in self.bucket:
            op, grad = self.bucket.pop(param_id)
            if op is not None:
                op.wait()
        return grad

    def synchronize(self, params):
        for p in params:
            i = id(p)
            if i in self.bucket:
                op, grad = self.bucket.pop(i)
                if op is not None:
                    op.wait()
                p.grad.add_(grad)


_async_grad_bucket = AsyncGradientBucket()


def push_async_grad(op, grad, param_id):
    return _async_grad_bucket.push(op, grad, param_id)


def pop_async_grad(param_id):
    return _async_grad_bucket.pop(param_id)


def _async_grad_hook(grad, param_id):
    grad.add_(pop_async_grad(param_id))
    return grad


def register_async_grad_hook(param):
    param.register_hook(partial(_async_grad_hook, param_id=id(param)))


def synchronize(params=list()):
    _async_grad_bucket.synchronize(params)
    torch.cuda.default_stream().synchronize()
    if len(_async_grad_bucket) > 0:
        raise RuntimeError(f"{len(_async_grad_bucket)} asynchronous gradient(s) not collected.")
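AsyncGradientBucket is what lets the 3D autograd functions return a zero placeholder immediately: the real, still-in-flight gradient is parked under the parameter's id(), and the hook installed by register_async_grad_hook waits on the collective and folds the parked tensor back in when autograd delivers the placeholder-based gradient to that parameter. A collective-free sketch of that round trip follows; the simplified helpers mirror the ones in the diff but are illustrative, not a drop-in.

# Illustrative sketch, not a drop-in: the deferred-gradient round trip without collectives.
from collections import OrderedDict

import torch

bucket = OrderedDict()                      # param_id -> (async_op, grad_tensor)


def push_async_grad(op, grad, param_id):
    bucket[param_id] = (op, grad)
    return torch.zeros_like(grad)           # what the autograd Function hands back for now


def pop_async_grad(param_id):
    op, grad = bucket.pop(param_id)
    if op is not None:
        op.wait()                           # the real code completes the all-reduce here
    return grad


def register_async_grad_hook(param):
    # add the parked gradient back in when autograd delivers the placeholder-based grad
    param.register_hook(lambda g: g + pop_async_grad(id(param)))


param = torch.nn.Parameter(torch.ones(3))
register_async_grad_hook(param)

# inside a custom backward: park the (already reduced) grad, return zeros to autograd
_ = push_async_grad(None, torch.full((3,), 2.0), id(param))

param.sum().backward()
print(param.grad)                           # tensor([3., 3., 3.]): local ones + parked twos
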
@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
@@ -20,9 +20,9 @@ from torch import Tensor
from torch.nn import Parameter

from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (all_gather_tensor_3d, broadcast_weight_3d_from_diagonal, classifier_3d, layernorm_3d,
                         linear_3d, reduce_scatter_tensor_3d, split_tensor_3d)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook


@LAYERS.register_module
@@ -45,7 +45,8 @@ class LayerNorm3D(ParallelLayer):
        super().__init__()
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
        self.depth = get_depth_from_env()
        self.normalized_shape = normalized_shape
        self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
@@ -58,6 +59,7 @@ class LayerNorm3D(ParallelLayer):
        else:
            self.bias = None
        self.variance_epsilon = eps
        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self) -> None:
@@ -67,8 +69,10 @@ class LayerNorm3D(ParallelLayer):

    def reset_parameters(self) -> None:
        init.ones_()(self.weight)
        register_async_grad_hook(self.weight)
        if self.bias is not None:
            init.zeros_()(self.bias)
            register_async_grad_hook(self.bias)

    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
@@ -134,8 +138,17 @@ class LayerNorm3D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
                            self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
        return layernorm_3d(
            input_,
            self.weight,
            self.bias,
            self.normalized_shape,
            self.variance_epsilon,
            self.input_parallel_mode,
            self.weight_parallel_mode,
            self.output_parallel_mode,
            self.input_x_weight_parallel_mode,
        )


@LAYERS.register_module
@@ -161,6 +174,7 @@ class Linear3D(ParallelLayer):
                 out_features: int,
                 bias: bool = True,
                 dtype: torch.dtype = None,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()
@@ -168,8 +182,10 @@ class Linear3D(ParallelLayer):
        self.out_features = out_features
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
        self.depth = get_depth_from_env()
        self.skip_bias_add = skip_bias_add
        self.in_features_per_partition = divide(in_features, self.depth)
        self.out_features_per_partition = divide(out_features, self.depth**2)
        self.bias_features_per_partition = divide(out_features, self.depth)
@@ -194,18 +210,23 @@ class Linear3D(ParallelLayer):
        if self.bias is not None:
            set_tensor_parallel_attribute_by_partition(self.bias, self.depth)

    def _sync_grad_hook(self, grad) -> Tensor:
        grad = all_reduce(grad.clone(), self.output_x_weight_parallel_mode)
        return grad

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        with seed(ParallelMode.TENSOR):
            fan_in, fan_out = self.in_features, self.out_features

            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            register_async_grad_hook(self.weight)

            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)
                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                broadcast(self.bias, output_src_rank, self.output_parallel_mode)
                broadcast(self.bias,
                          gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
                          self.output_x_weight_parallel_mode)
                self.bias.register_hook(self._sync_grad_hook)

    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
@@ -324,8 +345,20 @@ class Linear3D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                         self.output_parallel_mode)
        output = linear_3d(
            input_,
            self.weight,
            self.input_parallel_mode,
            self.weight_parallel_mode,
            self.output_parallel_mode,
        )

        if not self.skip_bias_add:
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias


@LAYERS.register_module
@@ -360,7 +393,7 @@ class Classifier3D(ParallelLayer):
        self.num_classes = num_classes
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.depth = get_depth_from_env()
        self.in_features_per_partition = divide(in_features, self.depth)

@@ -386,19 +419,17 @@ class Classifier3D(ParallelLayer):
    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        with seed(ParallelMode.TENSOR):
            fan_in, fan_out = self.in_features, self.num_classes
            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
            output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
            input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]

            if self.has_weight:
                weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
                broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
                broadcast(self.weight, gpc.get_ranks_in_group(self.weight_parallel_mode)[0], self.weight_parallel_mode)

            register_async_grad_hook(self.weight)

            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)
                broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                broadcast(self.bias, output_src_rank, self.output_parallel_mode)
                broadcast(self.bias, input_src_rank, self.input_parallel_mode)
                broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], ParallelMode.TENSOR)
                register_async_grad_hook(self.bias)

    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
@@ -468,8 +499,14 @@ class Classifier3D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
                             self.output_parallel_mode)
        return classifier_3d(
            input_,
            self.weight,
            self.bias,
            self.input_parallel_mode,
            self.weight_parallel_mode,
            self.output_parallel_mode,
        )


@LAYERS.register_module
@@ -504,7 +541,8 @@ class VocabParallelClassifier3D(ParallelLayer):
        self.num_classes = num_classes
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
        self.depth = get_depth_from_env()
        self.in_features_per_partition = divide(in_features, self.depth)
        self.out_features_per_partition = divide(num_classes, self.depth**2)
@@ -544,12 +582,14 @@ class VocabParallelClassifier3D(ParallelLayer):
            if self.has_weight:
                weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)

            register_async_grad_hook(self.weight)

            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)
                weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
                output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
                broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
                broadcast(self.bias, output_src_rank, self.output_parallel_mode)
                broadcast(self.bias,
                          gpc.get_ranks_in_group(self.output_x_weight_parallel_mode)[0],
                          self.output_x_weight_parallel_mode)
                register_async_grad_hook(self.bias)

    def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
        local_state = OrderedDict()
@@ -668,8 +708,14 @@ class VocabParallelClassifier3D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
                         self.weight_parallel_mode, self.output_parallel_mode)
        return vocab_parallel_classifier_3d(
            input_,
            self.weight,
            self.bias,
            self.input_parallel_mode,
            self.weight_parallel_mode,
            self.output_parallel_mode,
        )


@LAYERS.register_module
@@ -708,12 +754,16 @@ class PatchEmbedding3D(ParallelLayer):
        self.depth = get_depth_from_env()
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.patch_size = to_2tuple(patch_size)
        grid_size = to_2tuple(img_size // patch_size)
        num_patches = grid_size[0] * grid_size[1]
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.embed_size = embed_size
        embed_size_per_partition = divide(embed_size, self.depth)
        embed_size_per_partition = embed_size // self.depth
        self.flatten = flatten

        self.weight = nn.Parameter(
@@ -725,7 +775,7 @@ class PatchEmbedding3D(ParallelLayer):
        self.cls_token = nn.Parameter(
            torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
        self.pos_embed = nn.Parameter(
            torch.zeros((1, num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))
            torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype))

        self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
        self._set_tensor_parallel_attributes()
@@ -737,8 +787,7 @@ class PatchEmbedding3D(ParallelLayer):
        set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)

    def _sync_grad_hook(self, grad) -> Tensor:
        grad = all_reduce(grad.clone(), self.input_parallel_mode)
        grad = all_reduce(grad, self.weight_parallel_mode)
        grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
        return grad

    def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer) -> None:
@@ -749,14 +798,10 @@ class PatchEmbedding3D(ParallelLayer):
            bias_initializer(self.bias, fan_in=fan_in)
            position_embed_initializer(self.pos_embed)

            weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
            input_src_rank = gpc.get_ranks_in_group(self.input_parallel_mode)[0]
            broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
            broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
            broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
            broadcast(self.weight, input_src_rank, self.input_parallel_mode)
            broadcast(self.bias, input_src_rank, self.input_parallel_mode)
            broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
            src_rank = gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0]
            broadcast(self.weight, src_rank, self.input_x_weight_parallel_mode)
            broadcast(self.bias, src_rank, self.input_x_weight_parallel_mode)
            broadcast(self.pos_embed, src_rank, self.input_x_weight_parallel_mode)

            self.weight.register_hook(self._sync_grad_hook)
            self.bias.register_hook(self._sync_grad_hook)
@@ -850,11 +895,12 @@ class PatchEmbedding3D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        input_ = split_batch_3d(input_,
                                input_parallel_mode=self.input_parallel_mode,
                                weight_parallel_mode=self.weight_parallel_mode)
        output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
            output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

        cls_token = self.cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
@@ -906,7 +952,8 @@ class Embedding3D(ParallelLayer):
        self.depth = get_depth_from_env()
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.input_x_weight_parallel_mode = get_parallel_mode_from_env(INPUT_X_WEIGHT_3D)

        self.num_embeddings = num_embeddings
        self.embed_dim = embedding_dim
@@ -924,13 +971,18 @@ class Embedding3D(ParallelLayer):
    def _set_tensor_parallel_attributes(self) -> None:
        set_tensor_parallel_attribute_by_partition(self.weight, self.depth)

    def _sync_grad_hook(self, grad) -> Tensor:
        grad = all_reduce(grad.clone(), self.input_x_weight_parallel_mode)
        return grad

    def reset_parameters(self, weight_initializer) -> None:
        with seed(ParallelMode.TENSOR):
            fan_in, fan_out = self.num_embeddings, self.embed_dim
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            self._fill_padding_idx_with_zero()
        weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
        broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
        broadcast(self.weight,
                  gpc.get_ranks_in_group(self.input_x_weight_parallel_mode)[0], self.input_x_weight_parallel_mode)
        self.weight.register_hook(self._sync_grad_hook)

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
@@ -981,11 +1033,10 @@ class Embedding3D(ParallelLayer):
        destination.update(local_state)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
        input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
        weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
                                                   self.output_parallel_mode)
        output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
        input_ = split_batch_3d(input_,
                                input_parallel_mode=self.input_parallel_mode,
                                weight_parallel_mode=self.weight_parallel_mode)
        output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

        return output

@@ -1039,7 +1090,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
        self.depth = get_depth_from_env()
        self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
        self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
        self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
        self.output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
        self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth**2)
        self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
        vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)