Mirror of https://github.com/hpcaitech/ColossalAI.git
[kernel] fixed repeated loading of kernels (#2549)

* [kernel] fixed repeated loading of kernels
* polish code
* polish code

Parent: 8438c35a5f
Commit: dd14783f75
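At a glance, the commit moves kernel loading out of the hot path: the pre-built colossalai._C extension is imported once at module level (falling back to a None sentinel), the forward pass JIT-builds a missing kernel at most once per process and caches it in a module-level global, and builder log output goes through a new print_rank_0 helper so multi-process runs do not spam the console. Below is a minimal, CPU-only sketch of that load-once idiom; DummyBuilder is a stand-in for the real op builders (LayerNormBuilder, ScaledMaskedSoftmaxBuilder), not ColossalAI code.

# Load-once idiom used throughout this commit, reduced to a runnable toy.
class DummyBuilder:

    def load(self):
        print("building kernel ...")    # in ColossalAI this JIT-compiles the CUDA extension
        return object()                 # stand-in for the compiled extension module


try:
    from colossalai._C import layer_norm    # pre-built kernel, if compiled at install time
except ImportError:
    layer_norm = None                        # sentinel: build lazily on first use


def forward_like():
    global layer_norm
    if layer_norm is None:                   # only the first call pays the build cost
        layer_norm = DummyBuilder().load()
    return layer_norm


assert forward_like() is forward_like()      # later calls reuse the cached module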
__init__.py

@@ -1,3 +1,5 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax
+from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+
+__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
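With ScaledUpperTriangMaskedSoftmax now re-exported and listed in __all__, both fused softmax entry points can be imported from the kernel package directly. A usage sketch follows; the package path colossalai.kernel.cuda_native is inferred from the relative imports above and assumes a ColossalAI installation.

# Assumed import path (inferred from the relative imports in the hunk above).
from colossalai.kernel.cuda_native import (
    FusedScaleMaskSoftmax,
    LayerNorm,
    MultiHeadAttention,
    ScaledUpperTriangMaskedSoftmax,
)

print([cls.__name__ for cls in (LayerNorm, MultiHeadAttention, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax)])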
layer_norm.py

@@ -9,24 +9,31 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn import init
 from torch.nn.parameter import Parameter
 
+from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
+
+try:
+    from colossalai._C import layer_norm
+except ImportError:
+    layer_norm = None
+
+
 class FusedLayerNormAffineFunction(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
-        try:
-            from colossalai._C import layer_norm
-        except ImportError:
-            from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
-            layer_norm = LayerNormBuilder().load()
-
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
 
+        global layer_norm
+        if layer_norm is None:
+            layer_norm = LayerNormBuilder().load()
+
         output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+        ctx.layernorm_op = layer_norm
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
 
         return output
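With the kernel cached in a module-level global, repeated forward calls no longer re-import or rebuild anything. A hedged usage sketch, assuming a GPU machine with ColossalAI installed; the package path and the MixedFusedLayerNorm(normalized_shape) constructor follow the Megatron/apex convention and are assumptions here, not taken from this diff.

import torch
from colossalai.kernel.cuda_native import LayerNorm   # MixedFusedLayerNorm re-export

layer = LayerNorm(1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

y1 = layer(x)        # first call may JIT-build and load the layer_norm kernel
y2 = layer(x)        # later calls reuse the module-level cached kernel
y2.sum().backward()  # backward also uses the already-loaded kernel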
@@ -34,12 +41,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        try:
-            from colossalai._C import layer_norm
-        except ImportError:
-            from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
-            layer_norm = LayerNormBuilder().load()
-
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
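The forward hunk also stashes the loaded module on the context (ctx.layernorm_op), so backward can reach the kernel without touching imports at all. Here is a self-contained, CPU-only toy of that ctx-handle pattern; ScaleByKernel is a made-up example, not ColossalAI code.

import torch


class ScaleByKernel(torch.autograd.Function):
    """Toy autograd function that stashes a non-tensor handle on ctx, like ctx.layernorm_op above."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.kernel = torch.mul    # stand-in for a lazily loaded kernel module
        ctx.scale = scale
        return ctx.kernel(x, scale)

    @staticmethod
    def backward(ctx, grad_output):
        # reuse the handle saved in forward instead of re-importing or re-loading
        return ctx.kernel(grad_output, ctx.scale), None


x = torch.randn(4, requires_grad=True)
y = ScaleByKernel.apply(x, 2.0)
y.sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))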
scaled_softmax.py

@@ -1,11 +1,17 @@
 """This code from NVIDIA Megatron
    with some changes. """
 
 import enum
 
 import torch
 import torch.nn as nn
 
+from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
+from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
+
+try:
+    from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
+except ImportError:
+    scaled_masked_softmax = None
+    scaled_upper_triang_masked_softmax = None
 
 
 class AttnMaskType(enum.Enum):
     padding = 1
@@ -23,7 +29,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, inputs, scale):
-        from colossalai.kernel import scaled_upper_triang_masked_softmax
+        global scaled_upper_triang_masked_softmax
+        if scaled_upper_triang_masked_softmax is None:
+            scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load()
 
         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
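A hedged usage sketch of the causal fused softmax, exercised twice so only the first call can trigger a build. It assumes a GPU with ColossalAI installed and the Megatron-style constraint that the input is a half-precision 3D tensor of attention scores with equal query and key lengths; treat the exact shape and dtype requirements as assumptions rather than documented guarantees.

import torch
from colossalai.kernel.cuda_native import ScaledUpperTriangMaskedSoftmax

scores = torch.randn(12, 128, 128, device="cuda", dtype=torch.float16)   # [batches, sq, sk], sq == sk
probs1 = ScaledUpperTriangMaskedSoftmax.apply(scores, 1.0)   # may build/load the kernel once
probs2 = ScaledUpperTriangMaskedSoftmax.apply(scores, 1.0)   # reuses the cached kernel
print(probs1.shape, torch.allclose(probs1, probs2))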
@@ -33,8 +41,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, output_grads):
-        from colossalai.kernel import scaled_upper_triang_masked_softmax
-
         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
@@ -52,30 +58,23 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, inputs, mask, scale):
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
         scale_t = torch.tensor([scale])
 
+        # build and load kernel if not pre-built
+        global scaled_masked_softmax
+        if scaled_masked_softmax is None:
+            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
+
         softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
         softmax_results, scale_t = ctx.saved_tensors
 
         input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
-        return input_grads, None, None
+        return input_grads, None, None, None
 
 
 class FusedScaleMaskSoftmax(nn.Module):
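The property the whole commit is after can be checked without a GPU: no matter how many times the forward path runs, the builder's load() fires at most once per process. A small CPU-only test sketch follows; DummyBuilder and fused_softmax are stand-ins for illustration, not ColossalAI code.

import torch

load_calls = 0
cached_softmax_op = None    # mirrors the module-level `scaled_masked_softmax = None` sentinel


class DummyBuilder:
    """Stand-in for ScaledMaskedSoftmaxBuilder; counts how often a kernel build happens."""

    def load(self):
        global load_calls
        load_calls += 1
        return torch.softmax    # pretend the compiled extension just exposes a softmax


def fused_softmax(x):
    global cached_softmax_op
    if cached_softmax_op is None:            # mirrors `if scaled_masked_softmax is None:`
        cached_softmax_op = DummyBuilder().load()
    return cached_softmax_op(x, -1)


for _ in range(100):
    fused_softmax(torch.randn(2, 8))
assert load_calls == 1, "the kernel should be built/loaded exactly once"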
@@ -113,14 +112,6 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
-
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-        self.scaled_masked_softmax = scaled_masked_softmax
-
         assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
@@ -186,4 +177,4 @@ class FusedScaleMaskSoftmax(nn.Module):
         return probs
 
     def get_batch_per_block(self, sq, sk, b, np):
-        return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
op_builder/builder.py

@@ -6,6 +6,23 @@ from pathlib import Path
 from typing import List
 
 
+def print_rank_0(message):
+    """
+    Print on only one process to avoid spamming.
+    """
+    try:
+        import torch.distributed as dist
+        if not dist.is_initialized():
+            is_main_rank = True
+        else:
+            is_main_rank = dist.get_rank() == 0
+    except ImportError:
+        is_main_rank = True
+
+    if is_main_rank:
+        print(message)
+
+
 class Builder(ABC):
     """
     Builder is the base class to build extensions for PyTorch.
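print_rank_0 keeps the build banners from being repeated by every worker in a multi-process launch, and it degrades to a plain print when torch.distributed is absent or uninitialized. A usage sketch; the helper's import path (op_builder.builder) is inferred from the surrounding hunks and should be treated as an assumption.

import torch.distributed as dist
from colossalai.kernel.op_builder.builder import print_rank_0   # path assumed

print_rank_0("this message appears once per job, not once per rank")
if dist.is_available() and dist.is_initialized():
    print(f"rank {dist.get_rank()} says hello")   # ordinary print: once per rank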
@@ -117,7 +134,7 @@ class Builder(ABC):
         try:
             op_module = self.import_op()
             if verbose:
-                print(f"OP {self.prebuilt_import_path} already exists, skip building.")
+                print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
         except ImportError:
             # construct the build directory
             import torch
@@ -130,9 +147,11 @@ class Builder(ABC):
             Path(build_directory).mkdir(parents=True, exist_ok=True)
 
             if verbose:
-                print("=========================================================================================")
-                print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
-                print("=========================================================================================")
+                print_rank_0(
+                    "=========================================================================================")
+                print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
+                print_rank_0(
+                    "=========================================================================================")
 
             # load the kernel
             op_module = load(name=self.name,
@@ -146,7 +165,7 @@ class Builder(ABC):
 
             build_duration = time.time() - start_build
             if verbose:
-                print(f"Time to load {self.name} op: {build_duration} seconds")
+                print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
 
             return op_module
 
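The verbose branch above reports the one-time build cost; loading the same builder twice in one process is typically much cheaper the second time because torch.utils.cpp_extension caches the compiled artifact in the build directory. A hedged sketch, assuming a CUDA toolchain and a ColossalAI checkout; LayerNormBuilder and its no-argument load() are taken from the hunks above.

import time
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder

start = time.time()
op = LayerNormBuilder().load()           # may JIT-build the CUDA extension on first use
print(f"first load:  {time.time() - start:.2f}s")

start = time.time()
op_again = LayerNormBuilder().load()     # cached build artifacts make this much faster
print(f"second load: {time.time() - start:.2f}s")

Even the fast path is not free, which is why the call sites in this commit cache the loaded module in a module-level global instead of going through the builder again on every forward or backward call.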