mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-04 21:29:41 +00:00
* [NFC] Polish colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu code style. (#937) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h code style (#939) * [NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#936) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h code style (#938) * [NFC] polish moe_cuda_kernel.cu code style (#940) Co-authored-by: Xiao Ye <xiaoye2@illinois.edu> * [NFC] polish pre-commit run --files colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu code style (#943) * [NFC] polish colossalai/kernel/cuda_native/csrc/moe_cuda.cpp code style (#942) * [NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.h code style (#945) * [NFC] polish colossalai/kernel/jit/bias_gelu.py code style (#946) Co-authored-by: jnbai <897086360@qq.com> * [NFC] polish colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu code style (#949) Co-authored-by: Jiatong <jiatong.han@u.nus.edu> * [NFC] polish colossalai/builder/pipeline.py code style (#951) * [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp code style (#952) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu code style (#953) Co-authored-by: 何晓昕 <cautious@hexiaoxins-MacBook-Pro.local> * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu code style (#954) * [NFC] polish colossalai/kernel/cuda_native/scaled_softmax.py code style (#955) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/context.h code style (#956) Co-authored-by: RichardoLuo <14049555596@qq.com> * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h code style (#957) * [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu code style (#958) * [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h code style (#962) * [NFC] polish 
colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp code style (#959) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu code style (#963) Co-authored-by: “Arsmart123 <202476410arsmart@gmail.com> * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h code style (#964) * [NFC] polish __init__.py code style (#965) * [NFC] polish colossalai/nn/layer/parallel_3d/layers.py code style (#966) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h (#968) code style * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h code style (#970) * [NFC] polish colossalai/nn/layer/parallel_2p5d/layers.py code style (#972) * [NFC] polish colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp code style (#973) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu code style (#974) * [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu code style (#977) * [NFC] polish colossalai/nn/layer/parallel_2d/layers.py code style (#976) * [NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu code style (#978) * [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu code style (#979) * [NFC] polish colossalai/kernel/cuda_native/layer_norm.py code style (#980) * [NFC] polish colossalai/nn/layer/utils/common.py code style (#983) Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com> Co-authored-by: yuxuan-lou <83441848+yuxuan-lou@users.noreply.github.com> Co-authored-by: Geng Zhang <34452939+zxgx@users.noreply.github.com> Co-authored-by: Maruyama_Aya <38985202+MaruyamaAya@users.noreply.github.com> Co-authored-by: XYE <92607131+Itok2000u@users.noreply.github.com> Co-authored-by: Xiao Ye <xiaoye2@illinois.edu> Co-authored-by: HaoyuQin <79465534+coder-chin@users.noreply.github.com> Co-authored-by: wky <64853922+wangkuangyi@users.noreply.github.com> Co-authored-by: bajiaoyu517 
<59548007+bajiaoyu517@users.noreply.github.com> Co-authored-by: luoling-LC <105470086+luoling-LC@users.noreply.github.com> Co-authored-by: jnbai <897086360@qq.com> Co-authored-by: JT.Han <59948448+JThh@users.noreply.github.com> Co-authored-by: Jiatong <jiatong.han@u.nus.edu> Co-authored-by: xyupeng <99191637+xyupeng@users.noreply.github.com> Co-authored-by: Sze-qq <68757353+Sze-qq@users.noreply.github.com> Co-authored-by: Cautiousss <48676630+Cautiousss@users.noreply.github.com> Co-authored-by: 何晓昕 <cautious@hexiaoxins-MacBook-Pro.local> Co-authored-by: Luxios22 <67457897+Luxios22@users.noreply.github.com> Co-authored-by: Wangbo Zhao(黑色枷锁) <56866854+wangbo-zhao@users.noreply.github.com> Co-authored-by: RichardoLuo <50363844+RichardoLuo@users.noreply.github.com> Co-authored-by: RichardoLuo <14049555596@qq.com> Co-authored-by: doubleHU <98150031+huxin711@users.noreply.github.com> Co-authored-by: runluo <68489000+run-qiao@users.noreply.github.com> Co-authored-by: MaxT <854721132@qq.com> Co-authored-by: superhao1995 <804673818@qq.com> Co-authored-by: ziyu huang <huang0ziyu@gmail.com> Co-authored-by: “Arsmart123 <202476410arsmart@gmail.com> Co-authored-by: Yuer867 <62204893+Yuer867@users.noreply.github.com> Co-authored-by: lucasliunju <lucasliunju@gmail.com> Co-authored-by: LuGY <74758262+Gy-Lu@users.noreply.github.com> Co-authored-by: ExtremeViscent <zhangyiqi55732@sina.com> Co-authored-by: Xu Kai <xukai16@foxmail.com> Co-authored-by: Zirui Zhu <zhuzr21@gmail.com> Co-authored-by: Ofey Chan <ofey206@gmail.com> Co-authored-by: DouJS <dujiangsu@163.com> Co-authored-by: Jie Zhu <chore.08-protist@icloud.com> Co-authored-by: shenggan <csg19971016@gmail.com> Co-authored-by: Kai Wang (Victor Kai) <37533040+kaiwang960112@users.noreply.github.com> Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com> Co-authored-by: Ziheng Qin <37519855+henryqin1997@users.noreply.github.com>
79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
"""This code is from NVIDIA apex:
|
|
https://github.com/NVIDIA/apex
|
|
with some changes. """
|
|
|
|
import numbers
|
|
import torch
|
|
from torch.nn.parameter import Parameter
|
|
from torch.nn import init
|
|
from torch.cuda.amp import custom_fwd, custom_bwd
|
|
import importlib
|
|
|
|
# Handle to the compiled ``colossal_layer_norm_cuda`` extension module.
# It stays ``None`` until ``MixedFusedLayerNorm.__init__`` imports the
# extension and stores it here. A ``global`` statement has no effect at
# module scope, so a plain assignment is all that is required.
colossal_layer_norm_cuda = None
|
|
|
|
|
|
class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd function that delegates to the fused layer-norm extension.

    Both directions call into the module-level ``colossal_layer_norm_cuda``
    handle, which must already refer to the loaded extension when they run.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        """Call ``forward_affine`` and record on ``ctx`` what ``backward`` reads.

        Returns only the output tensor; the ``mean`` and ``invvar`` results of
        the extension call are saved for the backward pass.
        """
        ctx.normalized_shape, ctx.eps = normalized_shape, eps

        # The extension is always handed contiguous tensors.
        x = input.contiguous()
        w = weight.contiguous()
        b = bias.contiguous()

        result = colossal_layer_norm_cuda.forward_affine(x, ctx.normalized_shape, w, b, ctx.eps)
        output, mean, invvar = result

        ctx.save_for_backward(x, w, b, mean, invvar)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        """Call ``backward_affine`` with the tensors saved by ``forward``.

        Returns one entry per ``forward`` argument: gradients for ``input``,
        ``weight`` and ``bias``, then ``None`` for ``normalized_shape`` and
        ``eps``.
        """
        x, w, b, mean, invvar = ctx.saved_tensors

        grads = colossal_layer_norm_cuda.backward_affine(
            grad_output.contiguous(),
            mean,
            invvar,
            x,
            ctx.normalized_shape,
            w,
            b,
            ctx.eps,
        )
        grad_input, grad_weight, grad_bias = grads

        return grad_input, grad_weight, grad_bias, None, None
|
|
|
|
|
|
class MixedFusedLayerNorm(torch.nn.Module):
    """Layer-norm module backed by the ``colossal_layer_norm_cuda`` extension.

    Args:
        normalized_shape: A single integer or an iterable of integers. It is
            stored as a ``torch.Size`` and used as the shape of ``weight`` and
            ``bias``.
        eps: Value passed through to the extension calls. Defaults to ``1e-5``.
        device: Forwarded to ``torch.empty`` when allocating the parameters.
        dtype: Forwarded to ``torch.empty`` when allocating the parameters.

    Raises:
        RuntimeError: If the ``colossal_layer_norm_cuda`` extension cannot be
            imported. The original ``ImportError`` is attached as the cause.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
        super().__init__()

        # Import the extension on first construction and cache it in the
        # module-level handle so later instances skip the import.
        global colossal_layer_norm_cuda
        if colossal_layer_norm_cuda is None:
            try:
                colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
            except ImportError as err:
                # Chain explicitly so the real import failure stays visible.
                raise RuntimeError('MixedFusedLayerNorm requires cuda extensions') from err

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        # Allocate from the stored torch.Size rather than the raw argument:
        # a one-shot iterable has already been consumed by torch.Size above.
        self.weight = Parameter(torch.empty(*self.normalized_shape, device=device, dtype=dtype))
        self.bias = Parameter(torch.empty(*self.normalized_shape, device=device, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``weight`` with ones and ``bias`` with zeros, in place."""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        """Apply ``FusedLayerNormAffineFunction`` to ``input`` with this module's parameters."""
        return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)

    def __repr__(self):
        """Return a one-line description showing ``normalized_shape`` and ``eps``."""
        return f'MixedFusedLayerNorm(normalized_shape={self.normalized_shape}, eps={self.eps})'
|