Mirror of https://github.com/hpcaitech/ColossalAI.git
Hotfix/Colossalai layers (#92)
* optimized 1d layer apis; reorganized nn.layer modules; fixed tests

* fixed 2.5d runtime issue

* reworked split batch, now called in trainer.schedule.load_batch

Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
colossalai/nn/layer/colossalai_layer/__init__.py  (new file, 7 lines)
@@ -0,0 +1,7 @@
from ._utils import split_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm

__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
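For illustration, a minimal sketch of pulling the unified wrappers out of this new subpackage; it is not part of the diff and assumes a Colossal-AI parallel context has already been launched, since each wrapper queries get_tensor_parallel_mode() at construction time:

# Illustrative usage, not part of the commit. Assumes the global parallel context
# is initialized so get_tensor_parallel_mode() returns 'None', '1d', '2d', '2.5d' or '3d'.
from colossalai.nn.layer.colossalai_layer import Dropout, LayerNorm, Linear

proj = Linear(768, 3072)   # dispatches to nn.Linear or Linear1D/2D/2p5D/3D
norm = LayerNorm(768)      # parallel LayerNorm only for '2d'/'2.5d'/'3d'
drop = Dropout(0.1)        # Dropout1D under '1d', nn.Dropout otherwise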
colossalai/nn/layer/colossalai_layer/_utils.py  (new file, 19 lines)
@@ -0,0 +1,19 @@
from torch import Tensor

from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..utils import get_tensor_parallel_mode

_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}


def split_batch(input_) -> Tensor:
    tensor_parallel_mode = get_tensor_parallel_mode()
    if tensor_parallel_mode in _parallel_split_batch:
        if isinstance(input_, (tuple, list)):
            return tuple(map(_parallel_split_batch[tensor_parallel_mode], input_))
        else:
            return _parallel_split_batch[tensor_parallel_mode](input_)
    else:
        return input_
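The commit message notes that the split-batch step is now invoked from trainer.schedule.load_batch rather than inside individual layers. A hedged sketch of such a call site follows; only split_batch itself comes from this diff, the schedule class and batch format are hypothetical:

# Hypothetical sketch of a schedule hooking split_batch into batch loading.
from colossalai.nn.layer.colossalai_layer import split_batch

class ExampleSchedule:
    def load_batch(self, data_iter):
        data, label = next(data_iter)            # assume (inputs, targets) pairs
        # Under '2d'/'2.5d'/'3d' tensor parallelism each tensor is split along the
        # batch dimension; under 'None'/'1d' the inputs pass through unchanged.
        data, label = split_batch((data, label))
        return data, label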
colossalai/nn/layer/colossalai_layer/dropout.py  (new file, 23 lines)
@@ -0,0 +1,23 @@
from contextlib import nullcontext

import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context

from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode


class Dropout(nn.Module):
    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super().__init__()
        self.tensor_parallel = get_tensor_parallel_mode()
        if self.tensor_parallel == '1d':
            self.drop = Dropout1D(p, inplace)
        else:
            self.drop = nn.Dropout(p, inplace)

    def forward(self, *args):
        cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
        with cm:
            return self.drop(*args)
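As an illustrative sketch of the dispatch (not from the diff; assumes an initialized parallel context):

import torch
from colossalai.nn.layer.colossalai_layer import Dropout

# Under '1d' this wraps Dropout1D; for '2d'/'2.5d'/'3d' it wraps nn.Dropout and
# runs the forward pass inside the seed(ParallelMode.TENSOR) context.
drop = Dropout(p=0.1)
y = drop(torch.randn(4, 16))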
colossalai/nn/layer/colossalai_layer/embedding.py  (new file, 107 lines)
@@ -0,0 +1,107 @@
import math
from typing import Callable, Optional

from colossalai.utils import get_current_device
from torch import dtype, nn

from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *

_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}

_parallel_patchembedding = {
    'None': VanillaPatchEmbedding,
    '1d': VanillaPatchEmbedding,
    '2d': PatchEmbedding2D,
    '2.5d': PatchEmbedding2p5D,
    '3d': PatchEmbedding3D
}


class Embedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: int = None,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.normal_(),
                 *args,
                 **kwargs) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == 'None':
            self.embed = nn.Embedding(num_embeddings,
                                      embedding_dim,
                                      padding_idx=padding_idx,
                                      device=get_current_device(),
                                      dtype=dtype,
                                      *args,
                                      **kwargs)
            weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
        else:
            self.embed = _parallel_embedding[tensor_parallel](
                num_embeddings,
                embedding_dim,
                padding_idx=padding_idx,
                dtype=dtype,
                weight_initializer=weight_initializer,
                *args,
                **kwargs,
            )

    @property
    def weight(self):
        return self.embed.weight

    def forward(self, *args):
        return self.embed(*args)


class PatchEmbedding(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
                 dtype: dtype = None,
                 flatten: bool = True,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 position_embed_initializer: Callable = init.zeros_()) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        self.embed = _parallel_patchembedding[tensor_parallel](
            img_size,
            patch_size,
            in_chans,
            embed_size,
            dtype=dtype,
            flatten=flatten,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
            position_embed_initializer=position_embed_initializer,
        )

    @property
    def weight(self):
        return self.embed.weight

    @property
    def bias(self):
        return self.embed.bias

    @property
    def pos_embed(self):
        return self.embed.pos_embed

    @property
    def cls_token(self):
        return self.embed.cls_token

    def forward(self, *args):
        return self.embed(*args)
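A hedged usage sketch (dimensions arbitrary; assumes the parallel context is already launched):

from colossalai.nn.layer.colossalai_layer import Embedding, PatchEmbedding

# Token embedding: plain nn.Embedding when tensor parallelism is 'None',
# otherwise Embedding1D/2D/2p5D/3D via _parallel_embedding.
tok_embed = Embedding(num_embeddings=50000, embedding_dim=768)

# ViT-style patch embedding: always routed through _parallel_patchembedding,
# with VanillaPatchEmbedding covering the 'None' and '1d' modes.
patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_chans=3, embed_size=768)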
colossalai/nn/layer/colossalai_layer/linear.py  (new file, 97 lines)
@@ -0,0 +1,97 @@
import math
from typing import Callable, Optional

from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn

from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *

_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}

_parallel_classifier = {
    'None': VanillaClassifier,
    '1d': Classifier1D,
    '2d': Classifier2D,
    '2.5d': Classifier2p5D,
    '3d': Classifier3D
}


class Linear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 **kwargs) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == 'None':
            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
            weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
            if bias:
                bias_initializer(self.layer.bias, fan_in=in_features)
        else:
            self.layer = _parallel_linear[tensor_parallel](
                in_features,
                out_features,
                bias=bias,
                dtype=dtype,
                weight_initializer=weight_initializer,
                bias_initializer=bias_initializer,
                **kwargs,
            )

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias

    def forward(self, *args):
        return self.layer(*args)


class Classifier(nn.Module):
    def __init__(
        self,
        in_features: int,
        num_classes: int,
        weight: nn.Parameter = None,
        bias: bool = True,
        dtype: dtype = None,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
    ) -> None:
        super().__init__()
        self.layer = _parallel_classifier[get_tensor_parallel_mode()](
            in_features,
            num_classes,
            weight=weight,
            bias=bias,
            dtype=dtype,
            weight_initializer=weight_initializer,
            bias_initializer=bias_initializer,
        )

    @property
    def weight(self):
        return self.layer.weight

    @property
    def bias(self):
        return self.layer.bias

    def forward(self, *args):
        return self.layer(*args)
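Similarly, a brief hedged sketch of the two wrappers (sizes arbitrary; assumes an initialized context):

from colossalai.nn.layer.colossalai_layer import Classifier, Linear

# Linear: nn.Linear plus the custom initializers when tensor parallelism is 'None',
# otherwise one of Linear1D/2D/2p5D/3D.
proj = Linear(in_features=768, out_features=3072)

# Classifier: always routed through _parallel_classifier; an existing nn.Parameter
# (e.g. a tied embedding weight) may be passed via the `weight` argument.
head = Classifier(in_features=768, num_classes=1000)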
colossalai/nn/layer/colossalai_layer/normalization.py  (new file, 35 lines)
@@ -0,0 +1,35 @@
from typing import Optional

from colossalai.utils import get_current_device
from torch import nn

from ... import init as init
from ..parallel_1d import *
from ..parallel_2d import *
from ..parallel_2p5d import *
from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *

_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}


class LayerNorm(nn.Module):
    def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel in ['None', '1d']:
            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
        else:
            self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)

    @property
    def weight(self):
        return self.norm.weight

    @property
    def bias(self):
        return self.norm.bias

    def forward(self, *args):
        return self.norm(*args)
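And a final hedged sketch for the normalization wrapper (assumes an initialized context):

from colossalai.nn.layer.colossalai_layer import LayerNorm

# torch.nn.LayerNorm for the 'None' and '1d' modes, LayerNorm2D/2p5D/3D otherwise;
# eps and dtype are forwarded in both cases.
norm = LayerNorm(768, eps=1e-5)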