added skip_bias_add for non-tp linear

zbian 2022-11-09 13:20:02 +08:00 committed by アマデウス
parent e5b1a0c9be
commit 653b0a620e
3 changed files with 493 additions and 440 deletions

View File

@@ -1,147 +1,141 @@
-import math
 import inspect
+import math
 from typing import Callable
-from colossalai.utils import get_current_device
+
 from torch import dtype, nn
 
+from colossalai.utils import get_current_device
+
 from ... import init as init
 from ..parallel_1d import *
 from ..parallel_2d import *
 from ..parallel_2p5d import *
 from ..parallel_3d import *
 from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
 from ._utils import ColossalaiModule
 
-_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
+_parallel_linear = {None: VanillaLinear, '1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
 
 _parallel_classifier = {
     None: VanillaClassifier,
     '1d': Classifier1D,
     '2d': Classifier2D,
     '2.5d': Classifier2p5D,
     '3d': Classifier3D
 }
 
 _vocab_parallel_classifier = {
     '1d': VocabParallelClassifier1D,
     '2d': VocabParallelClassifier2D,
     '2.5d': VocabParallelClassifier2p5D,
     '3d': VocabParallelClassifier3D
 }
 
 
 class Linear(ColossalaiModule):
     """Linear layer of colossalai.
 
     Args:
         in_features (int): size of each input sample.
         out_features (int): size of each output sample.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
         weight_initializer (:class:`typing.Callable`, optional):
             The initializer of weight, defaults to kaiming uniform initializer.
         bias_initializer (:class:`typing.Callable`, optional):
             The initializer of bias, defaults to xavier uniform initializer.
 
     Note: ``kwargs`` would contain different parameters when you use different parallelisms.
 
     The ``kwargs`` should contain parameters below:
     ::
 
         Linear1D:
             gather_output: bool (optional, default to be false)
             skip_bias_add: bool (optional, default to be false)
         Linear2D:
             skip_bias_add: bool (optional, default to be false)
         Linear2p5D:
             skip_bias_add: bool (optional, default to be false)
         Linear3D:
             None
 
     More details about ``initializer`` please refer to
     `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
     def __init__(self,
                  in_features: int,
                  out_features: int,
                  bias: bool = True,
                  dtype: dtype = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  **kwargs) -> None:
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel is None:
-            layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
-            weight_initializer(layer.weight, fan_in=in_features, fan_out=out_features)
-            if layer.bias is not None:
-                bias_initializer(layer.bias, fan_in=in_features)
-        else:
-            linear_cls = _parallel_linear[tensor_parallel]
-            gather_output = kwargs.pop('gather_output', None)
-            if 'gather_output' in inspect.signature(
-                    linear_cls.__init__).parameters.keys():    # gather_out arg is available
-                kwargs['gather_output'] = gather_output
-            layer = linear_cls(
-                in_features,
-                out_features,
-                bias=bias,
-                dtype=dtype,
-                weight_initializer=weight_initializer,
-                bias_initializer=bias_initializer,
-                **kwargs,
-            )
+        linear_cls = _parallel_linear[tensor_parallel]
+        gather_output = kwargs.pop('gather_output', None)
+        if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():    # gather_out arg is available
+            kwargs['gather_output'] = gather_output
+        layer = linear_cls(
+            in_features,
+            out_features,
+            bias=bias,
+            dtype=dtype,
+            weight_initializer=weight_initializer,
+            bias_initializer=bias_initializer,
+            **kwargs,
+        )
         super().__init__(layer)
 
 
 class Classifier(ColossalaiModule):
     """Classifier layer of colossalai.
 
     Args:
         in_features (int): size of each input sample.
         num_classes (int): number of classes.
         weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
         bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
         weight_initializer (:class:`typing.Callable`, optional):
             The initializer of weight, defaults to kaiming uniform initializer.
         bias_initializer (:class:`typing.Callable`, optional):
             The initializer of bias, defaults to xavier uniform initializer.
 
     More details about ``initializer`` please refer to
     `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
     def __init__(self,
                  in_features: int,
                  num_classes: int,
                  weight: nn.Parameter = None,
                  bias: bool = True,
                  dtype: dtype = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  vocab_parallel_limit: int = 2048) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if num_classes <= vocab_parallel_limit or tensor_parallel is None:
             layer = _parallel_classifier[tensor_parallel](
                 in_features,
                 num_classes,
                 weight=weight,
                 bias=bias,
                 dtype=dtype,
                 weight_initializer=weight_initializer,
                 bias_initializer=bias_initializer,
             )
         else:
             layer = _vocab_parallel_classifier[tensor_parallel](
                 in_features,
                 num_classes,
                 weight=weight,
                 bias=bias,
                 dtype=dtype,
                 weight_initializer=weight_initializer,
                 bias_initializer=bias_initializer,
             )
         super().__init__(layer)
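With `_parallel_linear[None]` now pointing at `VanillaLinear`, the wrapper above no longer special-cases the non-tensor-parallel path: every mode goes through the same dispatch, so keyword arguments such as `skip_bias_add` reach the non-TP layer too, while `gather_output` is only forwarded when the selected class actually accepts it. Below is a minimal, self-contained sketch of that signature-based forwarding idiom; the two stub classes are hypothetical stand-ins, not the real ColossalAI layers.

```python
import inspect


class VanillaLinearStub:
    """Stand-in for VanillaLinear: takes skip_bias_add but not gather_output."""

    def __init__(self, in_features, out_features, skip_bias_add=False):
        self.config = dict(in_features=in_features, out_features=out_features, skip_bias_add=skip_bias_add)


class Linear1DStub:
    """Stand-in for Linear1D: takes both gather_output and skip_bias_add."""

    def __init__(self, in_features, out_features, gather_output=False, skip_bias_add=False):
        self.config = dict(in_features=in_features, out_features=out_features,
                           gather_output=gather_output, skip_bias_add=skip_bias_add)


def build_linear(linear_cls, in_features, out_features, **kwargs):
    # Pop gather_output unconditionally, then hand it back only if the target
    # constructor declares it -- the same inspect.signature check used above.
    gather_output = kwargs.pop('gather_output', None)
    if 'gather_output' in inspect.signature(linear_cls.__init__).parameters.keys():
        kwargs['gather_output'] = gather_output
    return linear_cls(in_features, out_features, **kwargs)


print(build_linear(VanillaLinearStub, 16, 32, skip_bias_add=True).config)
print(build_linear(Linear1DStub, 16, 32, gather_output=True).config)
```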

View File

@@ -1,6 +1,14 @@
-from .layers import (DropPath, VanillaClassifier, VanillaLayerNorm, VanillaPatchEmbedding, WrappedDropout,
-                     WrappedDropPath)
+from .layers import (
+    DropPath,
+    VanillaClassifier,
+    VanillaLayerNorm,
+    VanillaLinear,
+    VanillaPatchEmbedding,
+    WrappedDropout,
+    WrappedDropPath,
+)
 
 __all__ = [
-    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath"
+    "VanillaLayerNorm", "VanillaPatchEmbedding", "VanillaClassifier", "DropPath", "WrappedDropout", "WrappedDropPath",
+    "VanillaLinear"
 ]

View File

@@ -1,290 +1,341 @@
 import math
 from typing import Callable
 
 import torch
 import torch.nn.functional as F
-from colossalai.context import seed
-from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.cuda import get_current_device
 from torch import Tensor
 from torch import nn as nn
+from torch.nn.parameter import Parameter
+
+from colossalai.context import seed
+from colossalai.nn import init as init
+from colossalai.registry import LAYERS
+from colossalai.utils.cuda import get_current_device
 
 from ..utils import to_2tuple
 
 
 def drop_path(x, drop_prob: float = 0., training: bool = False):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
 
     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
     'survival rate' as the argument.
 
     Args:
         drop_prob (float, optional): probability of dropping path, defaults 0.0.
         training (bool, optional): whether in training progress, defaults False.
     """
     if drop_prob == 0. or not training:
         return x
     keep_prob = 1 - drop_prob
     shape = (x.shape[0],) + (1,) * (x.ndim - 1)    # work with diff dim tensors, not just 2D ConvNets
     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
     random_tensor.floor_()    # binarize
     output = x.div(keep_prob) * random_tensor
     return output
 
 
 class DropPath(nn.Module):
     """
     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
 
     Args:
         drop_prob (float, optional): probability of dropping path, defaults None.
     """
 
     def __init__(self, drop_prob=None):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
 
     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training)
 
 
 class WrappedDropout(nn.Module):
     r"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. During training, randomly zeroes
     some elements of the input tensor with probability p using samples from a Bernoulli distribution. Each
     channel will be zeroed out independently on every forward call. Furthermore, the outputs are scaled by a factor of
     1/(1-p) during training. This means that during evaluation the module simply computes an identity function.
 
     Args:
         p (float, optional): probability of an element to be zeroed, defaults 0.5.
         inplace (bool, optional): whether to do dropout in-place, default to be False.
         mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
 
     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
 
     def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
         super().__init__()
         if p < 0 or p > 1:
             raise ValueError("dropout probability has to be between 0 and 1, "
                              "but got {}".format(p))
         self.p = p
         self.inplace = inplace
         if mode is None:
             self.func = self.nonefunc
         else:
             self.func = self.normalfunc
             self.mode = mode
 
     def nonefunc(self, inputs):
         return F.dropout(inputs, self.p, self.training, self.inplace)
 
     def normalfunc(self, inputs):
         with seed(self.mode):
             return F.dropout(inputs, self.p, self.training, self.inplace)
 
     def forward(self, inputs):
         return self.func(inputs)
 
 
 class WrappedDropPath(nn.Module):
     r"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     Here, it is wrapped with the context of seed manager.
 
     Args:
         p (float, optional): probability of dropping path, defaults 0.0.
         mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
 
     Note:
         The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
         in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
     """
 
     def __init__(self, p: float = 0., mode=None):
         super().__init__()
         self.p = p
         self.mode = mode
         if self.mode is None:
             self.func = self.nonefunc
         else:
             self.func = self.normalfunc
             self.mode = mode
 
     def nonefunc(self, inputs):
         return drop_path(inputs, self.p, self.training)
 
     def normalfunc(self, inputs):
         with seed(self.mode):
             return drop_path(inputs, self.p, self.training)
 
     def forward(self, inputs):
         return self.func(inputs)
 
 
 @LAYERS.register_module
 class VanillaPatchEmbedding(nn.Module):
     r"""
     2D Image to Patch Embedding
 
     Args:
         img_size (int): image size.
         patch_size (int): patch size.
         in_chans (int): number of channels of input image.
         embed_size (int): size of embedding.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
         flatten (bool, optional): whether to flatten output tensor, defaults to True.
         weight_initializer (:class:`typing.Callable`, optional):
             The initializer of weight, defaults to kaiming uniform initializer.
         bias_initializer (:class:`typing.Callable`, optional):
             The initializer of bias, defaults to xavier uniform initializer.
         position_embed_initializer (:class:`typing.Callable`, optional):
             The initializer of position embedding, defaults to zeros initializer.
 
     More details about initializer please refer to
     `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
     def __init__(self,
                  img_size: int,
                  patch_size: int,
                  in_chans: int,
                  embed_size: int,
                  flatten: bool = True,
                  dtype: torch.dtype = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                  position_embed_initializer: Callable = init.zeros_()):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
         self.img_size = img_size
         self.patch_size = patch_size
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
 
         self.weight = nn.Parameter(
             torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
         self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
         self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
         self.pos_embed = nn.Parameter(
             torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
 
         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
 
     def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
         fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         bias_initializer(self.bias, fan_in=fan_in)
         position_embed_initializer(self.pos_embed)
 
     def forward(self, input_: Tensor) -> Tensor:
         B, C, H, W = input_.shape
         assert H == self.img_size[0] and W == self.img_size[1], \
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
             output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC
 
         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
         output = output + self.pos_embed
         return output
 
 
 @LAYERS.register_module
 class VanillaClassifier(nn.Module):
     r"""Dense linear classifier.
 
     Args:
         in_features (int): size of each input sample.
         num_classes (int): number of classes.
         weight (:class:`torch.nn.Parameter`, optional): weight of the classifier, defaults to None.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
         flatten (bool, optional): whether to flatten output tensor, defaults to True.
         weight_initializer (:class:`typing.Callable`, optional):
             The initializer of weight, defaults to kaiming uniform initializer.
         bias_initializer (:class:`typing.Callable`, optional):
             The initializer of bias, defaults to xavier uniform initializer.
 
     More details about initializer please refer to
     `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
     """
 
     def __init__(self,
                  in_features: int,
                  num_classes: int,
                  weight: nn.Parameter = None,
                  bias: bool = True,
                  dtype: torch.dtype = None,
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         super().__init__()
         self.in_features = in_features
         self.num_classes = num_classes
 
         if weight is not None:
             self.weight = weight
             self.has_weight = False
         else:
             self.weight = nn.Parameter(
                 torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
             self.has_weight = True
         if bias:
             self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
         else:
             self.bias = None
 
         self.reset_parameters(weight_initializer, bias_initializer)
 
     def reset_parameters(self, weight_initializer, bias_initializer):
         fan_in, fan_out = self.in_features, self.num_classes
 
         if self.has_weight:
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
 
         if self.bias is not None:
             bias_initializer(self.bias, fan_in=fan_in)
 
     def forward(self, input_: Tensor) -> Tensor:
         return F.linear(input_, self.weight, self.bias)
 
 
 @LAYERS.register_module
 class VanillaLayerNorm(nn.Module):
     r"""
     Layer Normalization for colossalai
 
     Args:
         normalized_shape (int): input shape from an expected input of size.
             :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1]
                 \times \ldots \times \text{normalized_shape}[-1]]`
             If a single integer is used, it is treated as a singleton list, and this module will
             normalize over the last dimension which is expected to be of that specific size.
         eps (float): a value added to the denominator for numerical stability, defaults to 1e-05.
         bias (bool, optional): Whether to add a bias, defaults to ``True``.
         dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
     """
 
     def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
         super().__init__()
 
         self.normalized_shape = (normalized_shape,)
         self.variance_epsilon = eps
 
         factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
 
         self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
         if bias:
             self.bias = nn.Parameter(torch.zeros(normalized_shape, **factory_kwargs))
         else:
             self.bias = None
 
     def forward(self, x: Tensor) -> Tensor:
         return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.variance_epsilon)
+
+
+@LAYERS.register_module
+class VanillaLinear(nn.Module):
+    """Linear layer.
+
+    Args:
+        in_features (int): size of each input sample.
+        out_features (int): size of each output sample.
+        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
+        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
+        skip_bias_add: bool (optional, default to be false).
+        weight_initializer (:class:`typing.Callable`, optional):
+            The initializer of weight, defaults to kaiming uniform initializer.
+        bias_initializer (:class:`typing.Callable`, optional):
+            The initializer of bias, defaults to xavier uniform initializer.
+
+    More details about ``initializer`` please refer to
+    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 dtype: torch.dtype = None,
+                 skip_bias_add: bool = False,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 **kwargs) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.skip_bias_add = skip_bias_add
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
+        if bias:
+            self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
+        else:
+            self.bias = None
+        weight_initializer(self.weight, fan_in=in_features, fan_out=out_features)
+        if self.bias is not None:
+            bias_initializer(self.bias, fan_in=in_features)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not self.skip_bias_add:
+            return F.linear(input, self.weight, self.bias)
+        else:
+            return F.linear(input, self.weight), self.bias
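The new `VanillaLinear` above is what implements `skip_bias_add` on the non-tensor-parallel path: with `skip_bias_add=True` its `forward` returns the un-biased `F.linear` result together with the bias tensor, so a caller can fold the bias add into a later fused operation. A usage sketch (hypothetical sizes; assumes a working ColossalAI install at this commit and the usual `colossalai.nn.layer.vanilla` import path):

```python
import torch
import torch.nn.functional as F

from colossalai.nn.layer.vanilla import VanillaLinear

x = torch.randn(4, 512)

# skip_bias_add=True: forward returns (x @ W^T, bias), deferring the bias add
# so it can be fused with a following op, e.g. bias + GeLU.
layer = VanillaLinear(in_features=512, out_features=1024, skip_bias_add=True)
x = x.to(layer.weight.device, layer.weight.dtype)
out, bias = layer(x)
y = F.gelu(out + bias)

# Default skip_bias_add=False: behaves like a plain linear layer and returns a
# single tensor with the bias already added.
plain = VanillaLinear(in_features=512, out_features=1024)
z = plain(x)
```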