mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 22:52:25 +00:00
add colossalai kernel module (#55)
This commit is contained in:
3
colossalai/kernel/jit/__init__.py
Normal file
3
colossalai/kernel/jit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .option import _set_jit_fusion_options
|
||||
|
||||
_set_jit_fusion_options()
|
24
colossalai/kernel/jit/bias_dropout_add.py
Normal file
24
colossalai/kernel/jit/bias_dropout_add.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def bias_dropout_add(x, bias, residual, prob, training):
|
||||
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
|
||||
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def bias_dropout_add_fused_train(x: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return bias_dropout_add(x, bias, residual, prob, True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def bias_dropout_add_fused_inference(x: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return bias_dropout_add(x, bias, residual, prob, False)
|
41
colossalai/kernel/jit/bias_gelu.py
Normal file
41
colossalai/kernel/jit/bias_gelu.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import torch
|
||||
|
||||
|
||||
###### BIAS GELU FUSION/ NO AUTOGRAD ################
|
||||
# 1/sqrt(2*pi)-> 0.3989423
|
||||
# 1/sqrt(2) -> 0.70710678
|
||||
# sqrt(2/pi) -> 0.79788456
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
|
||||
@torch.jit.script
|
||||
def bias_gelu(bias, y):
|
||||
x = bias + y
|
||||
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_back(g, bias, y):
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return ff*g
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(bias, input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
tmp = bias_gelu_back(grad_output, bias, input)
|
||||
return tmp, tmp
|
||||
|
||||
bias_gelu_impl = GeLUFunction.apply
|
28
colossalai/kernel/jit/option.py
Normal file
28
colossalai/kernel/jit/option.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
|
||||
JIT_OPTIONS_SET = False
|
||||
|
||||
def _set_jit_fusion_options():
|
||||
"""Set PyTorch JIT layer fusion options."""
|
||||
global JIT_OPTIONS_SET
|
||||
if JIT_OPTIONS_SET == False:
|
||||
# flags required to enable jit fusion kernels
|
||||
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
||||
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
||||
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
|
||||
# nvfuser
|
||||
torch._C._jit_set_profiling_executor(True)
|
||||
torch._C._jit_set_profiling_mode(True)
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_nvfuser_enabled(True)
|
||||
torch._C._debug_set_autodiff_subgraph_inlining(False)
|
||||
else:
|
||||
# legacy pytorch fuser
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
|
||||
JIT_OPTIONS_SET = True
|
Reference in New Issue
Block a user