Hotfix/Colossalai layers (#92)

* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
アマデウス
2021-12-29 23:32:10 +08:00
committed by GitHub
parent 0fedef4f3c
commit 01a80cd86d
71 changed files with 1033 additions and 773 deletions

View File

@@ -0,0 +1,7 @@
from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']

View File

@@ -0,0 +1,19 @@
from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
def split_batch(input_) -> Tensor:
tensor_parallel_mode = get_tensor_parallel_mode()
if tensor_parallel_mode in _parallel_split_batch:
if isinstance(input_, (tuple, list)):
return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
else:
return _parallel_split_batch[tensor_parallel_mode](input_)
else:
return input_

View File

@@ -0,0 +1,23 @@
from contextlib import nullcontext
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context
from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
class Dropout(nn.Module):
def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super().__init__()
self.tensor_parallel = get_tensor_parallel_mode()
if self.tensor_parallel == '1d':
self.drop = Dropout1D(p, inplace)
else:
self.drop = nn.Dropout(p, inplace)
def forward(self, *args):
cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
with cm:
return self.drop(*args)

View File

@@ -0,0 +1,107 @@
import math
from typing import Callable, Optional
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
_parallel_patchembedding = {
'None': VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
'3d': PatchEmbedding3D
}
class Embedding(nn.Module):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
return self.embed.weight
def forward(self, *args):
return self.embed(*args)
class PatchEmbedding(nn.Module):
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
self.embed = _parallel_patchembedding[tensor_parallel](
img_size,
patch_size,
in_chans,
embed_size,
dtype=dtype,
flatten=flatten,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
position_embed_initializer=position_embed_initializer,
)
@property
def weight(self):
return self.embed.weight
@property
def bias(self):
return self.embed.bias
@property
def pos_embed(self):
return self.embed.pos_embed
@property
def cls_token(self):
return self.embed.cls_token
def forward(self, *args):
return self.embed(*args)

View File

@@ -0,0 +1,97 @@
import math
from typing import Callable, Optional
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
'None': VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
class Linear(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
in_features,
out_features,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
**kwargs,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)
class Classifier(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
) -> None:
super().__init__()
self.layer = _parallel_classifier[get_tensor_parallel_mode()](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
return self.layer.weight
@property
def bias(self):
return self.layer.bias
def forward(self, *args):
return self.layer(*args)

View File

@@ -0,0 +1,35 @@
from typing import Optional
from colossalai.utils import get_current_device
from torch import nn
from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
class LayerNorm(nn.Module):
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
@property
def weight(self):
return self.norm.weight
@property
def bias(self):
return self.norm.bias
def forward(self, *args):
return self.norm(*args)