mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-27 19:36:13 +00:00
[kernel] fixed repeated loading of kernels (#2549)
* [kernel] fixed repeated loading of kernels * polish code * polish code
This commit is contained in:
parent
8438c35a5f
commit
dd14783f75
@ -1,3 +1,5 @@
|
||||
from .layer_norm import MixedFusedLayerNorm as LayerNorm
|
||||
from .multihead_attention import MultiHeadAttention
|
||||
from .scaled_softmax import FusedScaleMaskSoftmax
|
||||
from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
|
||||
|
||||
__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
|
||||
|
@ -9,24 +9,31 @@ from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from torch.nn import init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
|
||||
|
||||
try:
|
||||
from colossalai._C import layer_norm
|
||||
except ImportError:
|
||||
layer_norm = None
|
||||
|
||||
|
||||
class FusedLayerNormAffineFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, input, weight, bias, normalized_shape, eps):
|
||||
try:
|
||||
from colossalai._C import layer_norm
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
|
||||
layer_norm = LayerNormBuilder().load()
|
||||
|
||||
ctx.normalized_shape = normalized_shape
|
||||
ctx.eps = eps
|
||||
input_ = input.contiguous()
|
||||
weight_ = weight.contiguous()
|
||||
bias_ = bias.contiguous()
|
||||
|
||||
global layer_norm
|
||||
if layer_norm is None:
|
||||
|
||||
layer_norm = LayerNormBuilder().load()
|
||||
output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
|
||||
ctx.layernorm_op = layer_norm
|
||||
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
|
||||
|
||||
return output
|
||||
@ -34,12 +41,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
try:
|
||||
from colossalai._C import layer_norm
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
|
||||
layer_norm = LayerNormBuilder().load()
|
||||
|
||||
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
grad_input, grad_weight, grad_bias \
|
||||
|
@ -1,11 +1,17 @@
|
||||
"""This code from NVIDIA Megatron
|
||||
with some changes. """
|
||||
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
|
||||
from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
|
||||
|
||||
try:
|
||||
from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
|
||||
except ImportError:
|
||||
scaled_masked_softmax = None
|
||||
scaled_upper_triang_masked_softmax = None
|
||||
|
||||
|
||||
class AttnMaskType(enum.Enum):
|
||||
padding = 1
|
||||
@ -23,7 +29,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, scale):
|
||||
from colossalai.kernel import scaled_upper_triang_masked_softmax
|
||||
global scaled_upper_triang_masked_softmax
|
||||
if scaled_upper_triang_masked_softmax:
|
||||
scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load()
|
||||
|
||||
scale_t = torch.tensor([scale])
|
||||
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
|
||||
@ -33,8 +41,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
from colossalai.kernel import scaled_upper_triang_masked_softmax
|
||||
|
||||
softmax_results, scale_t = ctx.saved_tensors
|
||||
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
||||
|
||||
@ -52,30 +58,23 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, mask, scale):
|
||||
try:
|
||||
from colossalai._C import scaled_masked_softmax
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
|
||||
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
|
||||
|
||||
scale_t = torch.tensor([scale])
|
||||
|
||||
# build and load kernel if not pre-built
|
||||
global scaled_masked_softmax
|
||||
if scaled_masked_softmax is None:
|
||||
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
|
||||
|
||||
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
|
||||
ctx.save_for_backward(softmax_results, scale_t)
|
||||
return softmax_results
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads):
|
||||
try:
|
||||
from colossalai._C import scaled_masked_softmax
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
|
||||
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
|
||||
|
||||
softmax_results, scale_t = ctx.saved_tensors
|
||||
|
||||
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
|
||||
return input_grads, None, None
|
||||
return input_grads, None, None, None
|
||||
|
||||
|
||||
class FusedScaleMaskSoftmax(nn.Module):
|
||||
@ -113,14 +112,6 @@ class FusedScaleMaskSoftmax(nn.Module):
|
||||
self.mask_func = mask_func
|
||||
self.softmax_in_fp32 = softmax_in_fp32
|
||||
self.scale = scale
|
||||
|
||||
try:
|
||||
from colossalai._C import scaled_masked_softmax
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
|
||||
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
|
||||
self.scaled_masked_softmax = scaled_masked_softmax
|
||||
|
||||
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
|
||||
|
||||
def forward(self, input, mask):
|
||||
@ -186,4 +177,4 @@ class FusedScaleMaskSoftmax(nn.Module):
|
||||
return probs
|
||||
|
||||
def get_batch_per_block(self, sq, sk, b, np):
|
||||
return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
|
||||
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
|
||||
|
@ -6,6 +6,23 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
|
||||
def print_rank_0(message):
|
||||
"""
|
||||
Print on only one process to avoid spamming.
|
||||
"""
|
||||
try:
|
||||
import torch.distributed as dist
|
||||
if not dist.is_initialized():
|
||||
is_main_rank = True
|
||||
else:
|
||||
is_main_rank = dist.get_rank() == 0
|
||||
except ImportError:
|
||||
is_main_rank = True
|
||||
|
||||
if is_main_rank:
|
||||
print(message)
|
||||
|
||||
|
||||
class Builder(ABC):
|
||||
"""
|
||||
Builder is the base class to build extensions for PyTorch.
|
||||
@ -117,7 +134,7 @@ class Builder(ABC):
|
||||
try:
|
||||
op_module = self.import_op()
|
||||
if verbose:
|
||||
print(f"OP {self.prebuilt_import_path} already exists, skip building.")
|
||||
print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
|
||||
except ImportError:
|
||||
# construct the build directory
|
||||
import torch
|
||||
@ -130,9 +147,11 @@ class Builder(ABC):
|
||||
Path(build_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if verbose:
|
||||
print("=========================================================================================")
|
||||
print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
|
||||
print("=========================================================================================")
|
||||
print_rank_0(
|
||||
"=========================================================================================")
|
||||
print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
|
||||
print_rank_0(
|
||||
"=========================================================================================")
|
||||
|
||||
# load the kernel
|
||||
op_module = load(name=self.name,
|
||||
@ -146,7 +165,7 @@ class Builder(ABC):
|
||||
|
||||
build_duration = time.time() - start_build
|
||||
if verbose:
|
||||
print(f"Time to load {self.name} op: {build_duration} seconds")
|
||||
print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
|
||||
|
||||
return op_module
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user