Hotfix/Colossalai layers (#92)

* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
アマデウス
2021-12-29 23:32:10 +08:00
committed by GitHub
parent 0fedef4f3c
commit 01a80cd86d
71 changed files with 1033 additions and 773 deletions

View File

@@ -1,3 +1,9 @@
from .colossalai_layer import *
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .utils import *
from .vanilla import *
from .wrapper import *

View File

@@ -1,231 +0,0 @@
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from torch.nn.modules.activation import *
from torch.nn.modules.adaptive import *
from torch.nn.modules.batchnorm import *
from torch.nn.modules.channelshuffle import *
from torch.nn.modules.conv import *
from torch.nn.modules.distance import *
from torch.nn.modules.dropout import *
from torch.nn.modules.flatten import *
from torch.nn.modules.fold import *
from torch.nn.modules.instancenorm import *
from torch.nn.modules.linear import *
from torch.nn.modules.normalization import *
from torch.nn.modules.padding import *
from torch.nn.modules.pixelshuffle import *
from torch.nn.modules.pooling import *
from torch.nn.modules.rnn import *
from torch.nn.modules.sparse import *
from torch.nn.modules.transformer import *
from torch.nn.modules.upsampling import *
from .. import init as init
from .vanilla import *
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
None: VanillaClassifier,
'1d': VanillaClassifier,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_embedding = {'3d': Embedding3D}
_parallel_patchembedding = {
None: VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None,
**kwargs) -> None:
super().__init__()
if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
tensor_parallel: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__()
if tensor_parallel in [None, '1d']:
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)
class Classifier(nn.Module):
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
tensor_parallel: Optional[str] = None) -> None:
super().__init__()
self.layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)

View File

@@ -0,0 +1,7 @@
from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']

View File

@@ -0,0 +1,19 @@
from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
def split_batch(input_) -> Tensor:
tensor_parallel_mode = get_tensor_parallel_mode()
if tensor_parallel_mode in _parallel_split_batch:
if isinstance(input_, (tuple, list)):
return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
else:
return _parallel_split_batch[tensor_parallel_mode](input_)
else:
return input_

View File

@@ -0,0 +1,23 @@
from contextlib import nullcontext
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context
from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
class Dropout(nn.Module):
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super().__init__()
self.tensor_parallel = get_tensor_parallel_mode()
if self.tensor_parallel == '1d':
self.drop = Dropout1D(p, inplace)
else:
self.drop = nn.Dropout(p, inplace)
def forward(self, *args):
cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
with cm:
return self.drop(*args)

View File

@@ -0,0 +1,107 @@
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
_parallel_patchembedding = {
'None': VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)

View File

@@ -0,0 +1,97 @@
import math
from typing import Callable, Optional
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
'None': VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class Classifier(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
) -> None:
super().__init__()
self.layer = _parallel_classifier[get_tensor_parallel_mode()](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from colossalai.utils import get_current_device
from torch import nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)

View File

@@ -1,35 +0,0 @@
# adapted from Megatron-LM
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
import torch
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply

View File

@@ -1,4 +1,4 @@
from .layers import Linear1D_Col, Linear1D_Row
from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D
__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']

View File

@@ -1,12 +1,21 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import torch
import torch.distributed as dist
from colossalai.constants import PARALLEL_INPUT_1D
from colossalai.core import global_context as gpc
from .._common_utils import divide
from ..utils import divide
def set_parallel_input(input_parallel: bool):
os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
def get_parallel_input():
return bool(os.environ[PARALLEL_INPUT_1D])
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):

View File

@@ -3,10 +3,10 @@
import math
import numbers
from contextlib import nullcontext
from typing import Callable, Tuple
import torch
import torch.distributed as dist
import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
@@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor
from torch import Tensor, dtype
from torch.nn.parameter import Parameter
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition
from ._operation import FusedLayerNormAffineFunction1D
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
split_forward_gather_backward)
@LAYERS.register_module
class Linear1D(torch.nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
parallel_input = get_parallel_input()
if not parallel_input:
self.layer = Linear1D_Col(in_features,
out_features,
bias=bias,
dtype=dtype,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
else:
self.layer = Linear1D_Row(in_features,
out_features,
bias=bias,
dtype=dtype,
parallel_input=parallel_input,
skip_bias_add=skip_bias_add,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, input_: Tensor) -> Tensor:
return self.layer(input_)
@LAYERS.register_module
class Classifier1D(ParallelLayer):
"""RowLinear with given weight"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
def _set_tensor_parallel_attributes(self):
if self.has_weight:
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
if self.parallel_input:
input_ = input_
else:
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
output_parallel = F.linear(input_, self.weight)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
output = output + self.bias
return output
@LAYERS.register_module
@@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(True)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
@@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
@@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
@LAYERS.register_module
class Embedding1D(ParallelLayer):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
return output
@LAYERS.register_module
class Dropout1D(ParallelLayer):
def __init__(self, p: float = 0.5, inplace: bool = False):
super().__init__()
self.parallel_input = get_parallel_input()
self.p = p
self.inplace = inplace
def forward(self, input_: Tensor) -> Tensor:
cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
with cm:
output = F.dropout(input_, self.p, self.training, self.inplace)
return output

View File

@@ -1,6 +1,6 @@
from ._operation import reduce_by_batch_2d, split_batch_2d
from ._operation import reduce_by_batch_2d, split_tensor_2d
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
__all__ = [
'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
]

View File

@@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple
import torch
import torch.distributed as dist
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
@@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
return grad, None, None
def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
@@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
return output / reduce_size
return output
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
return input_.clone()
def forward(ctx, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
ctx.reduce_mean = reduce_mean
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
ctx.reduce_size = reduce_size
return output.clone() / reduce_size
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
def backward(ctx, output_grad):
if ctx.reduce_mean:
return output_grad / ctx.reduce_size, None
else:
return output_grad, None

View File

@@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
from ._utils import assert_summa_initialization, get_summa_dim_from_env
@@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
input_ = split_batch_2d(input_)
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
@@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_2d(input_)
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

View File

@@ -1,7 +1,7 @@
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D
__all__ = [
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'Embedding2p5D'
]

View File

@@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
return gpc.get_local_rank(parallel_mode)
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
ctx.save_for_backward(A, B)
A_shape = A.shape
A = A.reshape((-1, A_shape[-1])).contiguous()
A = A.reshape((-1, A_shape[-1]))
B_shape = B.shape
B = B.reshape((-1, B_shape[-1])).contiguous()
B = B.reshape((-1, B_shape[-1]))
C_shape = (A.shape[0], B.shape[-1])
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
op_a.wait()
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
for op in [op_a, op_b]:
op.wait()
# use circular buffer to store the communication tensor
# 2 is enough for all cases
A_list = [torch.empty_like(A) for _ in range(2)]
B_list = [torch.empty_like(B) for _ in range(2)]
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opb = [None] * 2
A_list[0].copy_(A)
B_list[0].copy_(B)
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
cur = 0
for i in range(tesseract_dim):
src_a = i + tesseract_dim * row_rank
src_b = i + tesseract_dim * col_rank
src_a = src_a % tesseract_dim
src_b = src_b % tesseract_dim
A_temp = A_list[src_a]
B_temp = B_list[src_b]
torch.addmm(C, A_temp, B_temp, out=C)
if i != tesseract_dim - 1:
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + tesseract_dim,
group=col_group,
async_op=True)
if opa[cur] is not None:
opa[cur].wait()
if opb[cur] is not None:
opb[cur].wait()
torch.addmm(C, A_list[cur], B_list[cur], out=C)
cur = 1 - cur
src_a += 1
src_b += tesseract_dim
out = C.reshape(out_shape)
if ctx:
@@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
C_shape = (A.shape[0], B.shape[0])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim):
B_temp = B.clone()
src_b = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
src_c = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
if i == col_rank:
C = C_temp.clone()
# use circular buffer to store the communication tensor
# 2 is enough for all cases
B_list = [torch.empty_like(B) for _ in range(2)]
C_list = [torch.empty_like(C) for _ in range(2)]
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opb = [None] * 2
opr = [None] * 2
B_list[0].copy_(B)
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
cur = 0
for i in range(tesseract_dim):
if i != tesseract_dim - 1:
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + tesseract_dim,
group=col_group,
async_op=True)
if opr[cur] is not None:
opr[cur].wait()
if i - 2 == col_rank:
C.copy_(C_list[cur])
if opb[cur] is not None:
opb[cur].wait()
torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
cur = 1 - cur
src_b += tesseract_dim
src_c += 1
for op in opr:
op.wait()
if tesseract_dim - 2 == col_rank:
C.copy_(C_list[cur])
if tesseract_dim - 1 == col_rank:
C.copy_(C_list[1 - cur])
out = C.reshape(out_shape)
if ctx:
@@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
C_shape = (A.shape[-1], B.shape[-1])
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
for i in range(tesseract_dim):
A_temp = A.clone()
src_a = i + row_rank * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
src_c = col_rank + i * tesseract_dim + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
if i == row_rank:
C = C_temp.clone()
# use circular buffer to store the communication tensor
# 2 is enough for all cases
A_list = [torch.empty_like(A) for _ in range(2)]
C_list = [torch.empty_like(C) for _ in range(2)]
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
opa = [None] * 2
opr = [None] * 2
A_list[0].copy_(A)
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
cur = 0
for i in range(tesseract_dim):
if i != tesseract_dim - 1:
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
if opr[cur] is not None:
opr[cur].wait()
if i - 2 == row_rank:
C.copy_(C_list[cur])
if opa[cur] is not None:
opa[cur].wait()
torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
cur = 1 - cur
src_a += 1
src_c += tesseract_dim
for op in opr:
op.wait()
if tesseract_dim - 2 == row_rank:
C.copy_(C_list[cur])
if tesseract_dim - 1 == row_rank:
C.copy_(C_list[1 - cur])
out = C.reshape(out_shape)
if ctx:
@@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
src_rank = col_rank + dep_rank * (
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
@@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
return grad, None, None
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
@@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2p5d(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
return input_
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
return output / reduce_size
return output
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
return input_.clone()
def forward(ctx, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_mean = reduce_mean
if reduce_mean:
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_size = reduce_size
return output.clone() / reduce_size
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
return grad_output
def backward(ctx, output_grad):
if ctx.reduce_mean:
return output_grad / ctx.reduce_size, None
else:
return output_grad, None

View File

@@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..base_layer import ParallelLayer
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
split_batch_2p5d)
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
@@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_size = embed_size
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
self.embed_size_per_partition = embed_size // self.tesseract_dim**2
with seed(ParallelMode.TENSOR):
self.weight = Parameter(
@@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
self._set_tensor_parallel_attribute()
def _set_tensor_parallel_attribute(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
with seed(ParallelMode.TENSOR):
@@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
input_ = split_batch_2p5d(input_)
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
@@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
embed_dim_per_partition = embedding_dim // self.tesseract_dim**2
self.padding_idx = padding_idx
self.embed_args = args
@@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
@@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_2p5d(input_)
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)
if weight is not None:
self.weight = weight
@@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):

View File

@@ -1,6 +1,6 @@
from ._operation import reduce_by_batch_3d, split_batch_3d
from ._operation import reduce_by_batch_3d, split_tensor_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D
__all__ = [
'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
]

View File

@@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None
def split_batch_3d(input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
dim: int = 0) -> Tensor:
def split_tensor_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
@@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,
class reduce_by_batch_3d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
def forward(ctx,
input_: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean
if reduce_mean:
reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
ctx.reduce_size = reduce_size
return output.clone() / reduce_size
return output.clone()
@staticmethod
@custom_bwd
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
return output_grad, None, None
if ctx.reduce_mean:
return output_grad / ctx.reduce_size, None, None, None
else:
return output_grad, None, None, None
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):

View File

@@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import *
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
@LAYERS.register_module
@@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
self.pos_embed.register_hook(self._sync_grad_hook)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
@@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

View File

@@ -0,0 +1,7 @@
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
__all__ = [
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
]

View File

@@ -2,11 +2,12 @@
# -*- encoding: utf-8 -*-
import collections.abc
import os
from itertools import repeat
import numpy as np
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
from colossalai.utils import checkpoint
from torch import Tensor, nn
@@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, NUM_PARTITIONS, num_partitions)
def get_tensor_parallel_mode():
return os.environ[TENSOR_PARALLEL_MODE]
# From PyTorch internals

View File

@@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch import nn as nn
from .._common_utils import to_2tuple
from ..utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):