Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests
* fixed 2.5d runtime issue
* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
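The split-batch rework means batch splitting is no longer hidden inside the parallel layers; the trainer's schedule is expected to call the new `split_batch` helper when it loads a batch. A minimal sketch of that call site is shown below; the `load_batch` wrapper here is illustrative only and is not the actual trainer code.

```python
# Illustrative sketch only (not the real trainer.schedule code): how a schedule
# could apply the reworked split_batch helper when loading a batch.
from colossalai.nn.layer.colossalai_layer import split_batch

def load_batch(data_iter):
    # Fetch one (data, label) pair and split it along the batch dimension
    # according to the active tensor-parallel mode; for modes without a batch
    # split (e.g. None or '1d') split_batch returns its input unchanged.
    data, label = next(data_iter)
    return split_batch(data), split_batch(label)
```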
@@ -1,3 +1,9 @@
from .colossalai_layer import *
from .fused_bias_gelu import bias_gelu_impl
from .parallel_1d import *
from .parallel_2d import *
from .parallel_2p5d import *
from .parallel_3d import *
from .parallel_sequence import *
from .utils import *
from .vanilla import *
from .wrapper import *
@@ -1,231 +0,0 @@
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import dtype, nn
|
||||
from torch.nn.modules.activation import *
|
||||
from torch.nn.modules.adaptive import *
|
||||
from torch.nn.modules.batchnorm import *
|
||||
from torch.nn.modules.channelshuffle import *
|
||||
from torch.nn.modules.conv import *
|
||||
from torch.nn.modules.distance import *
|
||||
from torch.nn.modules.dropout import *
|
||||
from torch.nn.modules.flatten import *
|
||||
from torch.nn.modules.fold import *
|
||||
from torch.nn.modules.instancenorm import *
|
||||
from torch.nn.modules.linear import *
|
||||
from torch.nn.modules.normalization import *
|
||||
from torch.nn.modules.padding import *
|
||||
from torch.nn.modules.pixelshuffle import *
|
||||
from torch.nn.modules.pooling import *
|
||||
from torch.nn.modules.rnn import *
|
||||
from torch.nn.modules.sparse import *
|
||||
from torch.nn.modules.transformer import *
|
||||
from torch.nn.modules.upsampling import *
|
||||
|
||||
from .. import init as init
|
||||
|
||||
from .vanilla import *
|
||||
from .parallel_1d import *
|
||||
from .parallel_2d import *
|
||||
from .parallel_2p5d import *
|
||||
from .parallel_3d import *
|
||||
from .parallel_sequence import *
|
||||
|
||||
_parallel_linear = {'1d_col': Linear1D_Col, '1d_row': Linear1D_Row, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
|
||||
|
||||
_parallel_classifier = {
|
||||
None: VanillaClassifier,
|
||||
'1d': VanillaClassifier,
|
||||
'2d': Classifier2D,
|
||||
'2.5d': Classifier2p5D,
|
||||
'3d': Classifier3D
|
||||
}
|
||||
|
||||
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
|
||||
|
||||
_parallel_embedding = {'3d': Embedding3D}
|
||||
|
||||
_parallel_patchembedding = {
|
||||
None: VanillaPatchEmbedding,
|
||||
'1d': VanillaPatchEmbedding,
|
||||
'2d': PatchEmbedding2D,
|
||||
'2.5d': PatchEmbedding2p5D,
|
||||
'3d': PatchEmbedding3D
|
||||
}
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
tensor_parallel: Optional[str] = None,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel is None:
|
||||
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
|
||||
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
|
||||
if bias:
|
||||
bias_initializer(self.layer.bias, fan_in=in_features)
|
||||
else:
|
||||
self.layer = _parallel_linear[tensor_parallel](
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None, tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
|
||||
else:
|
||||
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.norm.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.norm.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.norm(*args)
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
tensor_parallel: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
if tensor_parallel in [None, '1d']:
|
||||
self.embed = nn.Embedding(num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=get_current_device(),
|
||||
dtype=dtype,
|
||||
*args,
|
||||
**kwargs)
|
||||
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||
else:
|
||||
self.embed = _parallel_embedding[tensor_parallel](
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
def __init__(self,
|
||||
img_size: int,
|
||||
patch_size: int,
|
||||
in_chans: int,
|
||||
embed_size: int,
|
||||
dtype: dtype = None,
|
||||
flatten: bool = True,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
position_embed_initializer: Callable = init.zeros_(),
|
||||
tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.embed = _parallel_patchembedding[tensor_parallel](
|
||||
img_size,
|
||||
patch_size,
|
||||
in_chans,
|
||||
embed_size,
|
||||
dtype=dtype,
|
||||
flatten=flatten,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
position_embed_initializer=position_embed_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.embed.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.embed.bias
|
||||
|
||||
@property
|
||||
def pos_embed(self):
|
||||
return self.embed.pos_embed
|
||||
|
||||
@property
|
||||
def cls_token(self):
|
||||
return self.embed.cls_token
|
||||
|
||||
def forward(self, *args):
|
||||
return self.embed(*args)
|
||||
|
||||
|
||||
class Classifier(nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: nn.Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
tensor_parallel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.layer = _parallel_classifier[tensor_parallel](
|
||||
in_features,
|
||||
num_classes,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer,
|
||||
)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, *args):
|
||||
return self.layer(*args)
|
colossalai/nn/layer/colossalai_layer/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm

__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
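The new `colossalai_layer` package gives every layer a mode-agnostic wrapper, so model code no longer has to pick `Linear2D`, `Linear2p5D`, etc. by hand. A short sketch, assuming `colossalai.launch` has already configured a tensor-parallel mode:

```python
# Sketch under the assumption that colossalai has been launched with a
# tensor-parallel config (e.g. mode='2.5d'); each wrapper then resolves to the
# matching parallel implementation internally.
import torch.nn as nn
from colossalai.nn.layer.colossalai_layer import Dropout, LayerNorm, Linear

class MLPBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, drop: float = 0.1):
        super().__init__()
        self.norm = LayerNorm(dim)         # nn.LayerNorm for None/'1d', LayerNorm2D/2p5D/3D otherwise
        self.fc1 = Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = Linear(hidden_dim, dim)
        self.drop = Dropout(drop)

    def forward(self, x):
        return self.drop(self.fc2(self.act(self.fc1(self.norm(x)))))
```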
colossalai/nn/layer/colossalai_layer/_utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from torch import Tensor

from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}


def split_batch(input_) -> Tensor:
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, (tuple, list)):
            return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_
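`split_batch` accepts either a single tensor or a tuple/list and splits along the batch dimension only for the modes registered in `_parallel_split_batch`; everything else passes through untouched. A small usage sketch (assuming a 2D/2.5D/3D tensor-parallel context is already initialized):

```python
# Usage sketch; requires an initialized tensor-parallel context for the
# '2d'/'2.5d'/'3d' modes, otherwise the inputs are returned unchanged.
import torch
from colossalai.nn.layer.colossalai_layer import split_batch

data = torch.randn(32, 3, 224, 224)
label = torch.randint(0, 10, (32,))
data, label = split_batch((data, label))   # tuples/lists are split element-wise
```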
colossalai/nn/layer/colossalai_layer/dropout.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from contextlib import nullcontext

import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode


class Dropout(nn.Module):
    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super().__init__()
        self.tensor_parallel = get_tensor_parallel_mode()
        if self.tensor_parallel == '1d':
            self.drop = Dropout1D(p, inplace)
        else:
            self.drop = nn.Dropout(p, inplace)

    def forward(self, *args):
        cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
        with cm:
            return self.drop(*args)
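The `Dropout` wrapper only changes which RNG state the mask is drawn from: for `None`/`'1d'` it behaves like plain `nn.Dropout`, while for the 2D/2.5D/3D modes the forward pass runs under `seed(ParallelMode.TENSOR)`. A minimal usage sketch, assuming the parallel context is already set up:

```python
# Minimal usage sketch (assumes colossalai.launch has set the tensor-parallel mode).
import torch
from colossalai.nn.layer.colossalai_layer import Dropout

drop = Dropout(p=0.1)       # Dropout1D for '1d', nn.Dropout otherwise
x = torch.randn(4, 16)
y = drop(x)                 # for 2D/2.5D/3D modes this runs under seed(ParallelMode.TENSOR)
```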
colossalai/nn/layer/colossalai_layer/embedding.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import math
from typing import Callable, Optional

from colossalai.utils import get_current_device
from torch import dtype, nn

from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *

_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}

_parallel_patchembedding = {
    'None': VanillaPatchEmbedding,
    '1d': VanillaPatchEmbedding,
    '2d': PatchEmbedding2D,
    '2.5d': PatchEmbedding2p5D,
    '3d': PatchEmbedding3D
}


class Embedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: int = None,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.normal_(),
                 *args,
                 **kwargs) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == 'None':
            self.embed = nn.Embedding(num_embeddings,
                                      embedding_dim,
                                      padding_idx=padding_idx,
                                      device=get_current_device(),
                                      dtype=dtype,
                                      *args,
                                      **kwargs)
            weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
        else:
            self.embed = _parallel_embedding[tensor_parallel](
                num_embeddings,
                embedding_dim,
                padding_idx=padding_idx,
                dtype=dtype,
                weight_initializer=weight_initializer,
                *args,
                **kwargs,
            )

    @property
    def weight(self):
        return self.embed.weight

    def forward(self, *args):
        return self.embed(*args)


class PatchEmbedding(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
                 dtype: dtype = None,
                 flatten: bool = True,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 position_embed_initializer: Callable = init.zeros_()) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        self.embed = _parallel_patchembedding[tensor_parallel](
            img_size,
            patch_size,
            in_chans,
            embed_size,
            dtype=dtype,
            flatten=flatten,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
            position_embed_initializer=position_embed_initializer,
        )

    @property
    def weight(self):
        return self.embed.weight

    @property
    def bias(self):
        return self.embed.bias

    @property
    def pos_embed(self):
        return self.embed.pos_embed

    @property
    def cls_token(self):
        return self.embed.cls_token

    def forward(self, *args):
        return self.embed(*args)
colossalai/nn/layer/colossalai_layer/linear.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import math
from typing import Callable, Optional

from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn

from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *

_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}

_parallel_classifier = {
    'None': VanillaClassifier,
    '1d': Classifier1D,
    '2d': Classifier2D,
    '2.5d': Classifier2p5D,
    '3d': Classifier3D
}


class Linear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 **kwargs) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == 'None':
            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
            if bias:
                bias_initializer(self.layer.bias, fan_in=in_features)
        else:
            self.layer = _parallel_linear[tensor_parallel](
                in_features,
                out_features,
                bias=bias,
                dtype=dtype,
                weight_initializer=weight_initializer,
                bias_initializer=bias_initializer,
                **kwargs,
            )

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias

    def forward(self, *args):
        return self.layer(*args)


class Classifier(nn.Module):
    def __init__(
        self,
        in_features: int,
        num_classes: int,
        weight: nn.Parameter = None,
        bias: bool = True,
        dtype: dtype = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
    ) -> None:
        super().__init__()
        self.layer = _parallel_classifier[get_tensor_parallel_mode()](
            in_features,
            num_classes,
            weight=weight,
            bias=bias,
            dtype=dtype,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
        )

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias

    def forward(self, *args):
        return self.layer(*args)
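`Linear` and `Classifier` look up their backing implementation once, at construction time, from `get_tensor_parallel_mode()`. A short sketch of what that dispatch looks like from user code (the mode value is assumed for illustration):

```python
# Sketch assuming the launcher set the tensor-parallel mode to '2d'.
from colossalai.nn.layer.colossalai_layer import Classifier, Linear

proj = Linear(768, 3072)                               # wraps Linear2D in this mode
head = Classifier(in_features=768, num_classes=1000)   # wraps Classifier2D
print(type(proj.layer).__name__, type(head.layer).__name__)
```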
colossalai/nn/layer/colossalai_layer/normalization.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Optional

from colossalai.utils import get_current_device
from torch import nn

from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *

_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}


class LayerNorm(nn.Module):
    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel in ['None', '1d']:
            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
        else:
            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)

    @property
    def weight(self):
        return self.norm.weight

    @property
    def bias(self):
        return self.norm.bias

    def forward(self, *args):
        return self.norm(*args)
@@ -1,35 +0,0 @@
|
||||
# adapted from Megatron-LM
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/megatron/model/fused_bias_gelu.py
|
||||
|
||||
import torch
|
||||
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_back(g, bias, y):
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff*g
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(bias, input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_back(grad_output, bias, input)
|
||||
return tmp, tmp
|
||||
|
||||
bias_gelu_impl = GeLUFunction.apply
|
@@ -1,4 +1,4 @@
from .layers import Linear1D_Col, Linear1D_Row
from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
from .layers import MixedFusedLayerNorm1D as LayerNorm1D

__all__ = ['Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D']
__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
@@ -1,12 +1,21 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os

import torch
import torch.distributed as dist

from colossalai.constants import PARALLEL_INPUT_1D
from colossalai.core import global_context as gpc

from .._common_utils import divide
from ..utils import divide


def set_parallel_input(input_parallel: bool):
    os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''


def get_parallel_input():
    return bool(os.environ[PARALLEL_INPUT_1D])


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
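`set_parallel_input`/`get_parallel_input` persist a per-process flag in the `PARALLEL_INPUT_1D` environment variable, which is how the new `Linear1D` wrapper below decides between column- and row-parallel behavior. A conceptual sketch of the intended alternation, assuming the layers are constructed in execution order:

```python
# Conceptual sketch only: Linear1D_Col marks its output as already split
# (set_parallel_input(True)), so a Linear1D constructed afterwards sees
# get_parallel_input() == True and is built as Linear1D_Row.
from colossalai.nn.layer.parallel_1d import Linear1D

fc1 = Linear1D(1024, 4096)   # parallel input flag unset -> column-parallel
fc2 = Linear1D(4096, 1024)   # flag set by fc1 -> row-parallel
```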
@@ -3,10 +3,10 @@
|
||||
|
||||
import math
|
||||
import numbers
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from colossalai.communication import broadcast
|
||||
from colossalai.context import ParallelMode, seed
|
||||
@@ -14,13 +14,122 @@ from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils import get_current_device
|
||||
from torch import Tensor
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .._common_utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ._operation import FusedLayerNormAffineFunction1D
|
||||
from ._utils import (gather_forward_split_backward, reduce_grad, reduce_input, split_forward_gather_backward)
|
||||
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
||||
split_forward_gather_backward)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Linear1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
parallel_input = get_parallel_input()
|
||||
if not parallel_input:
|
||||
self.layer = Linear1D_Col(in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer)
|
||||
else:
|
||||
self.layer = Linear1D_Row(in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
parallel_input=parallel_input,
|
||||
skip_bias_add=skip_bias_add,
|
||||
weight_initializer=weight_initializer,
|
||||
bias_initializer=bias_initializer)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.layer.weight
|
||||
|
||||
@property
|
||||
def bias(self):
|
||||
return self.layer.bias
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
return self.layer(input_)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Classifier1D(ParallelLayer):
|
||||
"""RowLinear with given weight"""
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
num_classes: int,
|
||||
weight: Parameter = None,
|
||||
bias: bool = True,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.num_classes = num_classes
|
||||
self.parallel_input = get_parallel_input()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(self.num_classes, self.input_size_per_partition, **factory_kwargs))
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.num_classes, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.num_classes
|
||||
if self.has_weight:
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
if self.has_weight:
|
||||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
input_ = input_
|
||||
else:
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
|
||||
|
||||
output = output + self.bias
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
@@ -77,6 +186,7 @@ class Linear1D_Col(ParallelLayer):
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(True)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
@@ -158,6 +268,7 @@ class Linear1D_Row(ParallelLayer):
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
@@ -208,3 +319,68 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
|
||||
|
||||
def forward(self, input):
|
||||
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Embedding1D(ParallelLayer):
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: dtype = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
self._set_tensor_parallel_attributes()
|
||||
set_parallel_input(False)
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class Dropout1D(ParallelLayer):
|
||||
def __init__(self, p: float = 0.5, inplace: bool = False):
|
||||
super().__init__()
|
||||
self.parallel_input = get_parallel_input()
|
||||
self.p = p
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
|
||||
with cm:
|
||||
output = F.dropout(input_, self.p, self.training, self.inplace)
|
||||
return output
|
||||
|
@@ -1,6 +1,6 @@
from ._operation import reduce_by_batch_2d, split_batch_2d
from ._operation import reduce_by_batch_2d, split_tensor_2d
from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D

__all__ = [
    'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
]
@@ -2,7 +2,7 @@ from typing import Any, Optional, Tuple

import torch
import torch.distributed as dist
from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter)
from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device

@@ -595,7 +595,9 @@ class SplitFirst(torch.autograd.Function):
        return grad, None, None


def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
    if input_.size(dim) <= 1:
        return input_
    return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
                       dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()

@@ -603,17 +605,28 @@ def split_batch_2d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2d(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""
    @staticmethod
    def symbolic(graph, input_):
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        return input_
    def symbolic(graph, input_, reduce_mean: bool = False):
        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
        if reduce_mean:
            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
            return output / reduce_size
        return output

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input_):
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        return input_.clone()
    def forward(ctx, input_, reduce_mean: bool = False):
        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
        ctx.reduce_mean = reduce_mean
        if reduce_mean:
            reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2D_COL)
            ctx.reduce_size = reduce_size
            return output.clone() / reduce_size
        return output.clone()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        return grad_output
    def backward(ctx, output_grad):
        if ctx.reduce_mean:
            return output_grad / ctx.reduce_size, None
        else:
            return output_grad, None
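`reduce_by_batch_2d` now takes a `reduce_mean` flag: the value is all-reduced over the 2D column group and, when the flag is set, divided by the group size, with the backward pass scaled by the same factor. A hedged usage sketch (assumes an initialized 2D tensor-parallel context):

```python
# Usage sketch; requires an initialized 2D tensor-parallel process group.
import torch
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d

local_correct = torch.tensor(37.0, device='cuda')
# With reduce_mean=True the all-reduced count is averaged over the
# PARALLEL_2D_COL group instead of summed.
mean_correct = reduce_by_batch_2d.apply(local_correct, True)
```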
@@ -13,9 +13,9 @@ from colossalai.utils import get_current_device
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ..base_layer import ParallelLayer
|
||||
from ._operation import (Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d, split_batch_2d)
|
||||
from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
|
||||
from ._utils import assert_summa_initialization, get_summa_dim_from_env
|
||||
|
||||
|
||||
@@ -257,8 +257,6 @@ class PatchEmbedding2D(ParallelLayer):
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
input_ = split_batch_2d(input_)
|
||||
|
||||
weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
@@ -318,8 +316,6 @@ class Embedding2D(ParallelLayer):
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_2d(input_)
|
||||
|
||||
weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
@@ -1,7 +1,7 @@
from ._operation import reduce_by_batch_2p5d, split_batch_2p5d
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D

__all__ = [
    'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
    'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
    'Embedding2p5D'
]
@@ -22,7 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
|
||||
return gpc.get_local_rank(parallel_mode)
|
||||
|
||||
|
||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||
|
||||
@@ -120,30 +120,53 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
A_shape = A.shape
|
||||
A = A.reshape((-1, A_shape[-1])).contiguous()
|
||||
A = A.reshape((-1, A_shape[-1]))
|
||||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1])).contiguous()
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
|
||||
B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
|
||||
A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
|
||||
B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
|
||||
op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
|
||||
op_a.wait()
|
||||
op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
|
||||
for op in [op_a, op_b]:
|
||||
op.wait()
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
A_list = [torch.empty_like(A) for _ in range(2)]
|
||||
B_list = [torch.empty_like(B) for _ in range(2)]
|
||||
|
||||
row_group = gpc.get_group(row_parallel_mode)
|
||||
col_group = gpc.get_group(col_parallel_mode)
|
||||
|
||||
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
|
||||
opa = [None] * 2
|
||||
opb = [None] * 2
|
||||
|
||||
A_list[0].copy_(A)
|
||||
B_list[0].copy_(B)
|
||||
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
|
||||
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
|
||||
cur = 0
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
src_a = i + tesseract_dim * row_rank
|
||||
src_b = i + tesseract_dim * col_rank
|
||||
src_a = src_a % tesseract_dim
|
||||
src_b = src_b % tesseract_dim
|
||||
A_temp = A_list[src_a]
|
||||
B_temp = B_list[src_b]
|
||||
torch.addmm(C, A_temp, B_temp, out=C)
|
||||
if i != tesseract_dim - 1:
|
||||
A_list[1 - cur].copy_(A)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||
B_list[1 - cur].copy_(B)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||
src=src_b + tesseract_dim,
|
||||
group=col_group,
|
||||
async_op=True)
|
||||
|
||||
if opa[cur] is not None:
|
||||
opa[cur].wait()
|
||||
if opb[cur] is not None:
|
||||
opb[cur].wait()
|
||||
|
||||
torch.addmm(C, A_list[cur], B_list[cur], out=C)
|
||||
cur = 1 - cur
|
||||
src_a += 1
|
||||
src_b += tesseract_dim
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
@@ -201,20 +224,55 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
||||
C_shape = (A.shape[0], B.shape[0])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
B_temp = B.clone()
|
||||
src_b = col_rank + i * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(B_temp, src=src_b, group=gpc.get_group(col_parallel_mode))
|
||||
C_temp = torch.matmul(A, B_temp.transpose(0, 1))
|
||||
src_c = i + row_rank * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(C_temp, dst=src_c, group=gpc.get_group(row_parallel_mode))
|
||||
if i == col_rank:
|
||||
C = C_temp.clone()
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
B_list = [torch.empty_like(B) for _ in range(2)]
|
||||
C_list = [torch.empty_like(C) for _ in range(2)]
|
||||
|
||||
row_group = gpc.get_group(row_parallel_mode)
|
||||
col_group = gpc.get_group(col_parallel_mode)
|
||||
|
||||
src_b = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
src_c = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
|
||||
opb = [None] * 2
|
||||
opr = [None] * 2
|
||||
|
||||
B_list[0].copy_(B)
|
||||
opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
|
||||
cur = 0
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
if i != tesseract_dim - 1:
|
||||
B_list[1 - cur].copy_(B)
|
||||
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
|
||||
src=src_b + tesseract_dim,
|
||||
group=col_group,
|
||||
async_op=True)
|
||||
|
||||
if opr[cur] is not None:
|
||||
opr[cur].wait()
|
||||
if i - 2 == col_rank:
|
||||
C.copy_(C_list[cur])
|
||||
|
||||
if opb[cur] is not None:
|
||||
opb[cur].wait()
|
||||
|
||||
torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
|
||||
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
|
||||
cur = 1 - cur
|
||||
src_b += tesseract_dim
|
||||
src_c += 1
|
||||
|
||||
for op in opr:
|
||||
op.wait()
|
||||
|
||||
if tesseract_dim - 2 == col_rank:
|
||||
C.copy_(C_list[cur])
|
||||
if tesseract_dim - 1 == col_rank:
|
||||
C.copy_(C_list[1 - cur])
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
@@ -272,20 +330,52 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
||||
C_shape = (A.shape[-1], B.shape[-1])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
A_temp = A.clone()
|
||||
src_a = i + row_rank * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(A_temp, src=src_a, group=get_parallel_group(row_parallel_mode))
|
||||
C_temp = torch.matmul(A_temp.transpose(0, 1), B)
|
||||
src_c = col_rank + i * tesseract_dim + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.reduce(C_temp, dst=src_c, group=get_parallel_group(col_parallel_mode))
|
||||
if i == row_rank:
|
||||
C = C_temp.clone()
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
A_list = [torch.empty_like(A) for _ in range(2)]
|
||||
C_list = [torch.empty_like(C) for _ in range(2)]
|
||||
|
||||
row_group = gpc.get_group(row_parallel_mode)
|
||||
col_group = gpc.get_group(col_parallel_mode)
|
||||
|
||||
src_a = tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
src_c = col_rank + tesseract_dim ** 2 * dep_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
|
||||
opa = [None] * 2
|
||||
opr = [None] * 2
|
||||
|
||||
A_list[0].copy_(A)
|
||||
opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
|
||||
cur = 0
|
||||
|
||||
for i in range(tesseract_dim):
|
||||
if i != tesseract_dim - 1:
|
||||
A_list[1 - cur].copy_(A)
|
||||
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
|
||||
|
||||
if opr[cur] is not None:
|
||||
opr[cur].wait()
|
||||
if i - 2 == row_rank:
|
||||
C.copy_(C_list[cur])
|
||||
|
||||
if opa[cur] is not None:
|
||||
opa[cur].wait()
|
||||
|
||||
torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
|
||||
opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
|
||||
cur = 1 - cur
|
||||
src_a += 1
|
||||
src_c += tesseract_dim
|
||||
|
||||
for op in opr:
|
||||
op.wait()
|
||||
|
||||
if tesseract_dim - 2 == row_rank:
|
||||
C.copy_(C_list[cur])
|
||||
if tesseract_dim - 1 == row_rank:
|
||||
C.copy_(C_list[1 - cur])
|
||||
out = C.reshape(out_shape)
|
||||
|
||||
if ctx:
|
||||
@@ -333,8 +423,7 @@ class Add_Bias_2p5D(torch.autograd.Function):
|
||||
bias_temp = bias.clone()
|
||||
else:
|
||||
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
||||
src_rank = col_rank + dep_rank * (
|
||||
tesseract_dim ** 2) + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
src_rank = col_rank + dep_rank * tesseract_dim ** 2 + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
|
||||
pipeline_parallel_rank * tensor_parallel_size
|
||||
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
|
||||
|
||||
@@ -469,7 +558,9 @@ class SplitFirst(torch.autograd.Function):
|
||||
return grad, None, None
|
||||
|
||||
|
||||
def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
|
||||
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
|
||||
|
||||
@@ -477,17 +568,28 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
|
||||
class reduce_by_batch_2p5d(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
|
||||
return input_
|
||||
def symbolic(graph, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||
return output / reduce_size
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_):
|
||||
dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_COL))
|
||||
return input_.clone()
|
||||
def forward(ctx, input_, reduce_mean: bool = False):
|
||||
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
|
||||
ctx.reduce_mean = reduce_mean
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL)
|
||||
ctx.reduce_size = reduce_size
|
||||
return output.clone() / reduce_size
|
||||
return output.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
def backward(ctx, output_grad):
|
||||
if ctx.reduce_mean:
|
||||
return output_grad / ctx.reduce_size, None
|
||||
else:
|
||||
return output_grad, None
|
||||
|
@@ -13,10 +13,9 @@ from colossalai.utils import get_current_device
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..base_layer import ParallelLayer
|
||||
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d,
|
||||
split_batch_2p5d)
|
||||
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d)
|
||||
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env)
|
||||
|
||||
|
||||
@@ -231,7 +230,7 @@ class PatchEmbedding2p5D(ParallelLayer):
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
self.embed_size = embed_size
|
||||
self.embed_size_per_partition = embed_size // (self.tesseract_dep * self.tesseract_dim**2)
|
||||
self.embed_size_per_partition = embed_size // self.tesseract_dim**2
|
||||
|
||||
with seed(ParallelMode.TENSOR):
|
||||
self.weight = Parameter(
|
||||
@@ -251,10 +250,10 @@ class PatchEmbedding2p5D(ParallelLayer):
|
||||
self._set_tensor_parallel_attribute()
|
||||
|
||||
def _set_tensor_parallel_attribute(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.cls_token, self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
|
||||
with seed(ParallelMode.TENSOR):
|
||||
@@ -269,8 +268,6 @@ class PatchEmbedding2p5D(ParallelLayer):
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
|
||||
input_ = split_batch_2p5d(input_)
|
||||
|
||||
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
@@ -303,7 +300,7 @@ class Embedding2p5D(ParallelLayer):
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
embed_dim_per_partition = embedding_dim // (self.tesseract_dep * self.tesseract_dim**2)
|
||||
embed_dim_per_partition = embedding_dim // self.tesseract_dim**2
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
@@ -316,7 +313,7 @@ class Embedding2p5D(ParallelLayer):
|
||||
self._set_tensor_parallel_attributes()
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
@@ -330,8 +327,6 @@ class Embedding2p5D(ParallelLayer):
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_2p5d(input_)
|
||||
|
||||
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
@@ -359,7 +354,7 @@ class Classifier2p5D(ParallelLayer):
|
||||
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
|
||||
|
||||
# partitioning dimension
|
||||
self.input_size_per_partition = divide(self.in_features, self.tesseract_dep * self.tesseract_dim**2)
|
||||
self.input_size_per_partition = divide(self.in_features, self.tesseract_dim**2)
|
||||
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
@@ -378,7 +373,7 @@ class Classifier2p5D(ParallelLayer):
|
||||
|
||||
def _set_tensor_parallel_attributes(self):
|
||||
if self.has_weight:
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dep * self.tesseract_dim**2)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
with seed(ParallelMode.TENSOR):
|
||||
|
@@ -1,6 +1,6 @@
from ._operation import reduce_by_batch_3d, split_batch_3d
from ._operation import reduce_by_batch_3d, split_tensor_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D

__all__ = [
    'reduce_by_batch_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
    'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D'
]
@@ -175,10 +175,12 @@ class layernorm_3d(torch.autograd.Function):
|
||||
return input_grad, weight_grad, bias_grad, None, None, None, None, None
|
||||
|
||||
|
||||
def split_batch_3d(input_: Tensor,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
dim: int = 0) -> Tensor:
|
||||
def split_tensor_3d(input_: Tensor,
|
||||
dim: int = 0,
|
||||
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
|
||||
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
|
||||
if input_.size(dim) <= 1:
|
||||
return input_
|
||||
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
|
||||
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
|
||||
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
|
||||
@@ -189,15 +191,27 @@ def split_batch_3d(input_: Tensor,
|
||||
class reduce_by_batch_3d(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode) -> Tensor:
|
||||
def forward(ctx,
|
||||
input_: Tensor,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
reduce_mean: bool = False) -> Tensor:
|
||||
output = all_reduce(input_, input_parallel_mode)
|
||||
output = all_reduce(output, weight_parallel_mode)
|
||||
ctx.reduce_mean = reduce_mean
|
||||
if reduce_mean:
|
||||
reduce_size = gpc.get_world_size(input_parallel_mode) * gpc.get_world_size(weight_parallel_mode)
|
||||
ctx.reduce_size = reduce_size
|
||||
return output.clone() / reduce_size
|
||||
return output.clone()
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
return output_grad, None, None
|
||||
if ctx.reduce_mean:
|
||||
return output_grad / ctx.reduce_size, None, None, None
|
||||
else:
|
||||
return output_grad, None, None, None
|
||||
|
||||
|
||||
class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
|
||||
|
@@ -17,9 +17,9 @@ from colossalai.utils import get_current_device
|
||||
from torch import Tensor, dtype
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .._common_utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple)
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import *
|
||||
from ._utils import (get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group)
|
||||
from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_env, swap_in_out_group
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
@@ -241,8 +241,6 @@ class PatchEmbedding3D(ParallelLayer):
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size)
|
||||
@@ -302,8 +300,6 @@ class Embedding3D(ParallelLayer):
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
input_ = split_batch_3d(input_, self.input_parallel_mode, self.weight_parallel_mode)
|
||||
|
||||
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode,
|
||||
self.weight_parallel_mode, self.output_parallel_mode)
|
||||
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
colossalai/nn/layer/utils/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
                     set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)

__all__ = [
    'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
    'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
]
@@ -2,11 +2,12 @@
# -*- encoding: utf-8 -*-

import collections.abc
import os
from itertools import repeat

import numpy as np
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE)
from colossalai.utils import checkpoint
from torch import Tensor, nn

@@ -59,6 +60,10 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
    setattr(param, NUM_PARTITIONS, num_partitions)


def get_tensor_parallel_mode():
    return os.environ[TENSOR_PARALLEL_MODE]


# From PyTorch internals
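`get_tensor_parallel_mode` is just a lookup of the `TENSOR_PARALLEL_MODE` environment variable, which the ColossalAI context is expected to set during launch; every wrapper in `colossalai_layer` keys its dispatch off this value. A minimal illustration (setting the variable by hand only for demonstration):

```python
# For illustration only: in real use the launcher sets this variable.
import os
from colossalai.constants import TENSOR_PARALLEL_MODE
from colossalai.nn.layer.utils import get_tensor_parallel_mode

os.environ[TENSOR_PARALLEL_MODE] = '2.5d'
assert get_tensor_parallel_mode() == '2.5d'
```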
@@ -9,7 +9,7 @@ from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch import nn as nn

from .._common_utils import to_2tuple
from ..utils import to_2tuple


def drop_path(x, drop_prob: float = 0., training: bool = False):