diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 8f857ff5d..1d5a6ce49 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,3 +1,5 @@
 from .layer_norm import MixedFusedLayerNorm as LayerNorm
 from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax
+from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+
+__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py
index 4be336388..40355a41e 100644
--- a/colossalai/kernel/cuda_native/layer_norm.py
+++ b/colossalai/kernel/cuda_native/layer_norm.py
@@ -9,24 +9,31 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn import init
 from torch.nn.parameter import Parameter
 
+from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
+
+try:
+    from colossalai._C import layer_norm
+except ImportError:
+    layer_norm = None
+
 
 class FusedLayerNormAffineFunction(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
-        try:
-            from colossalai._C import layer_norm
-        except ImportError:
-            from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
-            layer_norm = LayerNormBuilder().load()
-
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
+
+        global layer_norm
+        if layer_norm is None:
+            layer_norm = LayerNormBuilder().load()
+
         output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+        ctx.layernorm_op = layer_norm
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
 
         return output
@@ -34,12 +41,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
-        try:
-            from colossalai._C import layer_norm
-        except ImportError:
-            from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
-            layer_norm = LayerNormBuilder().load()
-
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 44d750c5c..580e5c81a 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -1,11 +1,17 @@
-"""This code from NVIDIA Megatron
-   with some changes. """
-
 import enum
 
 import torch
 import torch.nn as nn
 
+from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
+from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
+
+try:
+    from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
+except ImportError:
+    scaled_masked_softmax = None
+    scaled_upper_triang_masked_softmax = None
+
 
 class AttnMaskType(enum.Enum):
     padding = 1
@@ -23,7 +29,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, inputs, scale):
-        from colossalai.kernel import scaled_upper_triang_masked_softmax
+        global scaled_upper_triang_masked_softmax
+        if scaled_upper_triang_masked_softmax is None:
+            scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load()
 
         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
@@ -33,8 +41,6 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, output_grads):
-        from colossalai.kernel import scaled_upper_triang_masked_softmax
-
         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
 
@@ -52,30 +58,23 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, inputs, mask, scale):
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
         scale_t = torch.tensor([scale])
+
+        # build and load the kernel if it is not pre-built
+        global scaled_masked_softmax
+        if scaled_masked_softmax is None:
+            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
 
         softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
         softmax_results, scale_t = ctx.saved_tensors
 
         input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
-        return input_grads, None, None
+        return input_grads, None, None, None
 
 
 class FusedScaleMaskSoftmax(nn.Module):
@@ -113,14 +112,6 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale
-
-        try:
-            from colossalai._C import scaled_masked_softmax
-        except ImportError:
-            from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-        self.scaled_masked_softmax = scaled_masked_softmax
-
         assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
 
     def forward(self, input, mask):
@@ -186,4 +177,4 @@ class FusedScaleMaskSoftmax(nn.Module):
         return probs
 
     def get_batch_per_block(self, sq, sk, b, np):
-        return self.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
diff --git a/op_builder/builder.py b/op_builder/builder.py
index dc9ea8e11..e2fdde3af 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -6,6 +6,23 @@ from pathlib import Path
 from typing import List
 
 
+def print_rank_0(message):
+    """
+    Print a message on only one process to avoid spamming the logs.
+    """
+    try:
+        import torch.distributed as dist
+        if not dist.is_initialized():
+            is_main_rank = True
+        else:
+            is_main_rank = dist.get_rank() == 0
+    except ImportError:
+        is_main_rank = True
+
+    if is_main_rank:
+        print(message)
+
+
 class Builder(ABC):
     """
     Builder is the base class to build extensions for PyTorch.
@@ -117,7 +134,7 @@ class Builder(ABC):
         try:
             op_module = self.import_op()
             if verbose:
-                print(f"OP {self.prebuilt_import_path} already exists, skip building.")
+                print_rank_0(f"OP {self.prebuilt_import_path} already exists, skip building.")
         except ImportError:
             # construct the build directory
             import torch
@@ -130,9 +147,11 @@ class Builder(ABC):
             Path(build_directory).mkdir(parents=True, exist_ok=True)
 
             if verbose:
-                print("=========================================================================================")
-                print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
-                print("=========================================================================================")
+                print_rank_0(
+                    "=========================================================================================")
+                print_rank_0(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
+                print_rank_0(
+                    "=========================================================================================")
 
             # load the kernel
             op_module = load(name=self.name,
@@ -146,7 +165,7 @@ class Builder(ABC):
         build_duration = time.time() - start_build
 
         if verbose:
-            print(f"Time to load {self.name} op: {build_duration} seconds")
+            print_rank_0(f"Time to load {self.name} op: {build_duration} seconds")
 
         return op_module
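
For reference, the loading pattern this diff introduces in each kernel module is: attempt to import the pre-built extension from colossalai._C at module import time, fall back to None, and JIT-build the kernel with its op builder on the first call, caching the handle in a module-level global so the build cost is paid at most once per process. The sketch below illustrates that pattern in isolation under stated assumptions; fake_prebuilt_kernel and FakeKernelBuilder are hypothetical stand-ins, not part of ColossalAI.

# Minimal, self-contained sketch of the "pre-built or JIT-build on first use"
# pattern adopted in the diff above. `fake_prebuilt_kernel` and
# `FakeKernelBuilder` are hypothetical stand-ins; the real code imports
# `colossalai._C` and the corresponding op builders instead.
import types

try:
    import fake_prebuilt_kernel as kernel    # hypothetical pre-built extension
except ImportError:
    kernel = None                            # not pre-built; build lazily later


class FakeKernelBuilder:
    """Stand-in for an op builder whose load() JIT-compiles and imports the kernel."""

    def load(self):
        # The real builders compile a CUDA extension here; we fake a module object.
        mod = types.SimpleNamespace()
        mod.forward = lambda x: [v * 2 for v in x]
        return mod


def fused_forward(x):
    # Mirrors the pattern used in the autograd functions: consult the
    # module-level handle, build it once if missing, then call into it.
    global kernel
    if kernel is None:
        kernel = FakeKernelBuilder().load()
    return kernel.forward(x)


if __name__ == "__main__":
    print(fused_forward([1, 2, 3]))    # builds (or reuses) the kernel, prints [2, 4, 6]

Keeping the handle in a module-level global, rather than re-importing inside every forward and backward as the old code did, lets backward reuse whatever forward loaded and lets repeated calls skip both the import attempt and the build.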