From 5c3843dc98007bac6b8b0ee6d0a9f120589bece4 Mon Sep 17 00:00:00 2001
From: shenggan
Date: Tue, 21 Dec 2021 12:19:52 +0800
Subject: [PATCH] add colossalai kernel module (#55)

---
 colossalai/kernel/__init__.py | 8 +
 colossalai/kernel/cuda_native/__init__.py | 17 +
 colossalai/kernel/cuda_native/builder.py | 114 ++
 colossalai/kernel/cuda_native/csrc/compat.h | 13 +
 .../cuda_native/csrc/kernels/cross_entropy.cu | 191 +++
 .../csrc/kernels/cublas_wrappers.cu | 87 ++
 .../cuda_native/csrc/kernels/cuda_util.cu | 169 +++
 .../csrc/kernels/dropout_kernels.cu | 1001 ++++++++++++++
 .../csrc/kernels/general_kernels.cu | 232 ++++
 .../csrc/kernels/include/block_reduce.h | 312 +++++
 .../csrc/kernels/include/context.h | 36 +
 .../kernels/include/cross_entropy_layer.h | 46 +
 .../csrc/kernels/include/cublas_wrappers.h | 40 +
 .../csrc/kernels/include/cuda_util.h | 34 +
 .../csrc/kernels/include/dropout.h | 95 ++
 .../csrc/kernels/include/feed_forward.h | 68 +
 .../csrc/kernels/include/kernels.h | 274 ++++
 .../csrc/kernels/include/ls_cub.cuh | 12 +
 .../csrc/kernels/include/normalize_layer.h | 65 +
 .../csrc/kernels/include/softmax.h | 44 +
 .../csrc/kernels/include/strided_batch_gemm.h | 99 ++
 .../csrc/kernels/normalize_kernels.cu | 1160 +++++++++++++++++
 .../csrc/kernels/softmax_kernels.cu | 366 ++++++
 .../csrc/kernels/transform_kernels.cu | 314 +++++
 .../cuda_native/csrc/layer_norm_cuda.cpp | 185 +++
 .../csrc/layer_norm_cuda_kernel.cu | 813 ++++++++++++
 .../csrc/multihead_attention_1d.cpp | 364 ++++++
 .../cuda_native/csrc/multihead_attention_1d.h | 153 +++
 .../csrc/scaled_masked_softmax.cpp | 84 ++
 .../cuda_native/csrc/scaled_masked_softmax.h | 492 +++++++
 .../csrc/scaled_masked_softmax_cuda.cu | 104 ++
 .../scaled_upper_triang_masked_softmax.cpp | 59 +
 .../csrc/scaled_upper_triang_masked_softmax.h | 500 +++++++
 ...scaled_upper_triang_masked_softmax_cuda.cu | 85 ++
 .../kernel/cuda_native/csrc/type_shim.h | 73 ++
 colossalai/kernel/cuda_native/layer_norm.py | 69 +
 .../kernel/cuda_native/multihead_attention.py | 270 ++++
 .../kernel/cuda_native/scaled_softmax.py | 184 +++
 colossalai/kernel/jit/__init__.py | 3 +
 colossalai/kernel/jit/bias_dropout_add.py | 24 +
 colossalai/kernel/jit/bias_gelu.py | 41 +
 colossalai/kernel/jit/option.py | 28 +
 setup.py | 1 +
 43 files changed, 8329 insertions(+)

 create mode 100644 colossalai/kernel/__init__.py
 create mode 100644 colossalai/kernel/cuda_native/__init__.py
 create mode 100644 colossalai/kernel/cuda_native/builder.py
 create mode 100644 colossalai/kernel/cuda_native/csrc/compat.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/context.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
 create mode 100644 colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
 create mode 100644 colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp
 create mode 100644 colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp
 create mode 100644 colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h
 create mode 100644 colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu
 create mode 100644 colossalai/kernel/cuda_native/csrc/type_shim.h
 create mode 100644 colossalai/kernel/cuda_native/layer_norm.py
 create mode 100644 colossalai/kernel/cuda_native/multihead_attention.py
 create mode 100644 colossalai/kernel/cuda_native/scaled_softmax.py
 create mode 100644 colossalai/kernel/jit/__init__.py
 create mode 100644 colossalai/kernel/jit/bias_dropout_add.py
 create mode 100644 colossalai/kernel/jit/bias_gelu.py
 create mode 100644 colossalai/kernel/jit/option.py

diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py
new file mode 100644
index 000000000..32bab15e5
--- /dev/null
+++ b/colossalai/kernel/__init__.py
@@ -0,0 +1,8 @@
+from .jit.bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
+from .jit.bias_gelu import bias_gelu_impl
+from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention
+
+__all__ = [
+    "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
+    "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
+]
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
new file mode 100644
index 000000000..33394224f
--- /dev/null
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -0,0 +1,17 @@
+from .builder import _build_cuda_native_kernel
+
+CUDA_NATIVE_KERNEL_BUILD = False
+
+
+def build_cuda_native_kernel():
+    global CUDA_NATIVE_KERNEL_BUILD
+    if CUDA_NATIVE_KERNEL_BUILD == False:
+        _build_cuda_native_kernel()
+        CUDA_NATIVE_KERNEL_BUILD = True
+
+
+build_cuda_native_kernel()
+
+from .layer_norm import MixedFusedLayerNorm as LayerNorm
+from .scaled_softmax import FusedScaleMaskSoftmax
+from .multihead_attention import MultiHeadAttention
\ No newline at end of file
diff --git a/colossalai/kernel/cuda_native/builder.py b/colossalai/kernel/cuda_native/builder.py
new file mode 100644
index 000000000..9f1d10e6e
--- /dev/null
+++ b/colossalai/kernel/cuda_native/builder.py
@@ -0,0 +1,114 @@
+import os
+import pathlib
+import subprocess
+
+from torch.utils import
cpp_extension + +# Setting this param to a list has a problem of generating different +# compilation commands (with diferent order of architectures) and +# leading to recompilation of fused kernels. Set it to empty string +# to avoid recompilation and assign arch flags explicity in +# extra_cuda_cflags below +os.environ["TORCH_CUDA_ARCH_LIST"] = "" + + +def _build_cuda_native_kernel(): + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + + # Build path + basepath = pathlib.Path(__file__).parent.absolute() + srcpath = basepath / 'csrc' + buildpath = basepath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=[ + '-O3', + ], + extra_include_paths=[str(srcpath / 'kernels' / 'include')], + extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math'] + + extra_cuda_flags + cc_flag, + verbose=False) + + # ============== + # Fused softmax. + # ============== + + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + + # Upper triangular softmax. + sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', + srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] + colossal_scaled_upper_triang_masked_softmax = _cpp_extention_load_helper( + "colossal_scaled_upper_triang_masked_softmax", + sources, extra_cuda_flags) + + # Masked softmax. + sources=[srcpath / 'scaled_masked_softmax.cpp', + srcpath / 'scaled_masked_softmax_cuda.cu'] + colossal_scaled_masked_softmax = _cpp_extention_load_helper( + "colossal_scaled_masked_softmax", sources, extra_cuda_flags) + + # ================================= + # Mixed precision fused layer norm. + # ================================= + + extra_cuda_flags = ['-maxrregcount=50'] + sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] + colossal_layer_norm_cuda = _cpp_extention_load_helper("colossal_layer_norm_cuda", sources, + extra_cuda_flags) + + # ========================================== + # Mixed precision Transformer Encoder Layer. 
+ # ========================================== + + extra_cuda_flags = ['-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', + '-DTHRUST_IGNORE_CUB_VERSION_CHECK'] + + sources = [srcpath / 'multihead_attention_1d.cpp'] + kernel_sources = ["cublas_wrappers.cu", + "transform_kernels.cu", + "dropout_kernels.cu", + "normalize_kernels.cu", + "softmax_kernels.cu", + "general_kernels.cu", + "cuda_util.cu"] + sources += [(srcpath / 'kernels' / cu_file) for cu_file in kernel_sources] + colossal_multihead_attention = _cpp_extention_load_helper("colossal_multihead_attention", sources, + extra_cuda_flags) + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h new file mode 100644 index 000000000..def1c2158 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -0,0 +1,13 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu new file mode 100644 index 000000000..58d26235a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu @@ -0,0 +1,191 @@ +#include "block_reduce.h" +#include "cuda_util.h" +#include "kernels.h" +#include "ls_cub.cuh" + +ls::cub::CachingDeviceAllocator g_allocator(true); + +template +__global__ void ls_cross_entropy_fw_kernel( + const T *__restrict__ inputs, const int *__restrict__ targets, + float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, + const int padding_idx, const float epsilon, const int vocab_size) { + /* step1: compute each thread's max_logit and sum_exp_logit, store in + * max_input, sum_exp_logit */ + const int block_start = blockIdx.x * vocab_size; + const int left_idx = block_start + threadIdx.x; + const int right_idx = (blockIdx.x + 1) * vocab_size; + float max_input[1] = {REDUCE_FLOAT_INF_NEG}; + float sum_logits[2] = {0.f, 0.f}; // logit and logit exp + int target_tid = targets[blockIdx.x]; + + if (target_tid == padding_idx) { + if (threadIdx.x == 0) { + nll_loss_outputs[blockIdx.x] = 0.f; + outputs[blockIdx.x] = 0.f; + } + return; + } + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); + } + blockReduce(max_input); + __shared__ float s_max_input; + if (threadIdx.x == 0) { + s_max_input = max_input[0]; + } + __syncthreads(); + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float logit = static_cast(inputs[i]) - s_max_input; + sum_logits[0] += logit; + sum_logits[1] += expf(logit); + } + + blockReduce(sum_logits); + __shared__ float s_sum_logit; + __shared__ float s_sum_exp; + if (threadIdx.x == 0) { + 
s_sum_logit = sum_logits[0]; + s_sum_exp = sum_logits[1]; + } + __syncthreads(); + + float eps_i = epsilon / (vocab_size - 1); + if (threadIdx.x == 0) { + // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) + float nll_loss = logf(s_sum_exp) - + static_cast(inputs[block_start + target_tid]) + + s_max_input; + nll_loss_outputs[blockIdx.x] = nll_loss; + float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; + outputs[blockIdx.x] = + (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; + } +} + +template +__global__ void ls_cross_entropy_bw_kernel( + const float *__restrict__ grad_outputs, const T *__restrict__ inputs, + const int *__restrict__ targets, T *__restrict__ grad_inputs, + const int padding_idx, const float epsilon, const int vocab_size) { + /* step1: compute each thread's max_logit and sum_exp_logit, store in + * max_input, sum_exp_logit */ + const int block_start = blockIdx.x * vocab_size; + const int left_idx = block_start + threadIdx.x; + const int right_idx = (blockIdx.x + 1) * vocab_size; + float max_input[1] = {REDUCE_FLOAT_INF_NEG}; + float sum_logits[1] = {0.f}; + const float grad_out = static_cast(grad_outputs[0]); + int target_tid = targets[blockIdx.x]; + + if (target_tid == padding_idx) { + for (int i = left_idx; i < right_idx; i += blockDim.x) { + grad_inputs[i] = 0.f; + } + return; + } + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); + } + blockReduce(max_input); + __shared__ float s_max_input; + if (threadIdx.x == 0) { + s_max_input = max_input[0]; + } + __syncthreads(); + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float logit = static_cast(inputs[i]) - s_max_input; + sum_logits[0] += expf(logit); + } + + blockReduce(sum_logits); + __shared__ float s_sum_exp; + if (threadIdx.x == 0) { + s_sum_exp = sum_logits[0]; + } + __syncthreads(); + + float eps_i = epsilon / (vocab_size - 1); + float nll_weight = 1.0 - epsilon - eps_i; + + for (int i = left_idx; i < right_idx; i += blockDim.x) { + float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; + float grad = 0; + grad += (vocab_size * prob - 1) * eps_i; + grad += prob * nll_weight; + if ((i - block_start) == target_tid) { + grad -= nll_weight; + } + grad_inputs[i] = grad_out * grad; + } +} + +template +void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, + float *outputs_ptr, float *nll_loss_ptr, + float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, + const int seq_len, const int vocab_size, + cudaStream_t stream) { + int grid_dim = batch_size * seq_len; + float *nll_loss_buffer = loss_buffer + grid_dim; + ls_cross_entropy_fw_kernel<<>>( + inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, + epsilon, vocab_size); + + int num_items = grid_dim; + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + loss_buffer, outputs_ptr, + num_items, stream)); + CHECK_GPU_ERROR( + g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + loss_buffer, outputs_ptr, + num_items, stream)); + CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + nll_loss_buffer, nll_loss_ptr, + num_items, stream)); + CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); +} + +template void launch_cross_entropy_fw( + const float *inputs_ptr, const int 
*targets_ptr, float *outputs_ptr, + float *nll_loss_ptr, float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template void launch_cross_entropy_fw<__half>( + const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr, float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template +void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr, + const int padding_idx, const float epsilon, + const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream) { + int grid_dim = batch_size * seq_len; + ls_cross_entropy_bw_kernel<<>>( + grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, + epsilon, vocab_size); +} + +template void launch_cross_entropy_bw( + const float *grad_outputs_ptr, const float *inputs_ptr, + const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template void launch_cross_entropy_bw<__half>( + const float *grad_outputs_ptr, const __half *inputs_ptr, + const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, + const float epsilon, const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu new file mode 100644 index 000000000..68be1f6d7 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu @@ -0,0 +1,87 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#include "cublas_wrappers.h" + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, cublasGemmAlgo_t algo) { + cublasStatus_t status = + cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, + (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, + (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, + (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmEx( + handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, + CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, + (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, + CUDA_R_16F, m, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. 
(m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmStridedBatchedEx( + handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, + (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, + (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, + batch, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " + "error: %d) \n", + batch, m, n, k, (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, cublasGemmAlgo_t algo) { + cublasStatus_t status = cublasGemmStridedBatchedEx( + handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, + (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, + (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, + batch, CUDA_R_32F, algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, n, k, (int)status); + return EXIT_FAILURE; + } + + return 0; +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu new file mode 100644 index 000000000..9a6a8ebc3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -0,0 +1,169 @@ +#include +#include + +#include "cuda_util.h" + +/* GPU function guard */ +std::string _cudaGetErrorString(cudaError_t error) { + return cudaGetErrorString(error); +} + +std::string _cudaGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return "CUBLAS_UNKNOW"; +} + +template +void check_gpu_error(T result, char const *const func, const char *const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + + std::to_string(line) + + "): " + (_cudaGetErrorString(result)) + "\n"); + } +} + +template void check_gpu_error(cudaError_t result, + char const *const func, + const char *const file, + int const line); +template void check_gpu_error(cublasStatus_t result, + char const *const func, + const char *const file, + int const line); + +template +void print_vec(const T 
*outv, std::string outn, int num_output_ele) { + std::cout << outn << ": "; + std::vector hout(num_output_ele, (T)0); + cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), + cudaMemcpyDeviceToHost); + for (int i = 0; i < num_output_ele; i++) { + std::cout << hout[i] << ", "; + } + std::cout << std::endl; +} + +template <> +void print_vec<__half>(const __half *outv, std::string outn, + int num_output_ele) { + std::cout << outn << ": "; + std::vector<__half> hout(num_output_ele, (__half)0.f); + cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), + cudaMemcpyDeviceToHost); + for (int i = 0; i < num_output_ele; i++) { + std::cout << __half2float(hout[i]) << ", "; + } + std::cout << std::endl; +} + +template void print_vec(const float *outv, std::string outn, + int num_output_ele); + +template void print_vec(const int *outv, std::string outn, + int num_output_ele); + +template void print_vec<__half>(const __half *outv, std::string outn, + int num_output_ele); + +template +T *cuda_malloc(size_t ele_num) { + size_t byte_size = ele_num * sizeof(T); + T *pdata = nullptr; + CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); + return pdata; +} + +template float *cuda_malloc(size_t ele_num); + +template __half *cuda_malloc<__half>(size_t ele_num); + +template uint8_t *cuda_malloc(size_t ele_num); + +void cuda_free(void *pdata) { + if (pdata != nullptr) { + cudaFree(pdata); + } +} + +template +struct _isnan { + __device__ bool operator()(T a) const { return isnan(a); } +}; + +template <> +struct _isnan<__half> { + __device__ bool operator()(const __half a) const { return __hisnan(a); } +}; + +template +struct _isinf { + __device__ bool operator()(T a) const { return isinf(a); } +}; + +template <> +struct _isinf<__half> { + __device__ bool operator()(const __half a) const { return __hisinf(a); } +}; + +template +void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, + std::string file, int line, cudaStream_t stream) { + // check_nan_inf = 0 for checking nan + // check_nan_inf = 1 for checking inf + bool res = false; + std::string msg = file + "(" + std::to_string(line) + "): "; + if (check_nan_inf) { + msg += "nan."; + res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, + data_ptr + dsize, _isnan(), false, + thrust::logical_or()); + } else { + msg += "inf."; + res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, + data_ptr + dsize, _isinf(), false, + thrust::logical_or()); + } + if (res) { + throw std::runtime_error(msg); + } + std::cout << msg << " [check pass]." 
<< std::endl; +} + +template void check_nan_inf(const float *data_ptr, int dsize, + bool check_nan_inf, std::string file, + int line, cudaStream_t stream); + +template void check_nan_inf<__half>(const __half *data_ptr, int dsize, + bool check_nan_inf, std::string file, + int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu new file mode 100644 index 000000000..7d314c11e --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -0,0 +1,1001 @@ +#include +#include + +#include "kernels.h" + +#include + +namespace cg = cooperative_groups; + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? 
grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = 
reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t *__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> 
+void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; + + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. 
- ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void 
ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const __half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 
grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 
*val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); + out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = 
flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. +// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, 
const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu new file mode 100644 index 000000000..e37bc3d04 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -0,0 +1,232 @@ +#include "kernels.h" + +#include + +namespace cg = cooperative_groups; + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. + +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. 
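  // After the barrier each thread reads the transposed element
  // tile[threadIdx.y][threadIdx.x], so one warp (fixed threadIdx.y, x = 0..31)
  // ends up holding the WARP_SIZE partial sums of a single input column and
  // can finish the reduction with the shfl_down loop below.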
+ // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; + val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, 
cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { + return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h new file mode 100644 index 000000000..38103c173 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h @@ -0,0 +1,312 @@ +/* Copyright 2021 The LightSeq Team + Copyright Tencent/TurboTransformers + This block_reduce_n is adapted from Tencent/TurboTransformers +*/ +#pragma once +#include +#include +#include + +enum class ReduceType { kMax = 0, kSum }; +const unsigned int WARP_REDUCE_MASK = 0xffffffff; +const float REDUCE_FLOAT_INF_NEG = -100000000.f; +const float REDUCE_FLOAT_INF_POS = 100000000.f; +const unsigned int WARP_REDUCE_SIZE = 32; + +template +__forceinline__ __device__ T warpReduceSum(T val) { + for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) + val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__forceinline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) shared[wid] = val; + __syncthreads(); + + val = (threadIdx.x < (blockDim.x >> 5)) ? 
shared[lane] : (T)0.0f; + val = warpReduceSum(val); + return val; +} + +template +__inline__ __device__ void blockReduce(float *pval); + +// use template to make code more concise +template +__inline__ __device__ void warpReduce(float *pval); + +// static +template <> +__inline__ __device__ void warpReduce(float *pval) { + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32)); + *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32)); +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp; +#define WarpReduceMaxOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval) = max(val0_tmp, *(pval)); \ + *(pval + 1) = max(val1_tmp, *(pval + 1)); + + WarpReduceMaxOneStep(16, 32); + WarpReduceMaxOneStep(8, 32); + WarpReduceMaxOneStep(4, 32); + WarpReduceMaxOneStep(2, 32); + WarpReduceMaxOneStep(1, 32); +#undef WarpReduceMaxOneStep +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32); + *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32); +} + +/* + * Unorll for loop for warpreduce to + * imporve instruction issue efficiency + * ElemX means there are X numbers to be summed + */ + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp; +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp + + WarpReduceSumOneStep(16, 32); + WarpReduceSumOneStep(8, 32); + WarpReduceSumOneStep(4, 32); + WarpReduceSumOneStep(2, 32); + WarpReduceSumOneStep(1, 32); + +#undef WarpReduceSumOneStep +} + +template <> +__inline__ __device__ void warpReduce(float *pval) { + float val0_tmp, val1_tmp, val2_tmp, val3_tmp; +#define WarpReduceSumOneStep(a, b) \ + val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ + val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ + val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ + val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ + *(pval + 0) += val0_tmp; \ + *(pval + 1) += val1_tmp; \ + *(pval + 2) += val2_tmp; \ + *(pval + 3) += val3_tmp + + WarpReduceSumOneStep(16, 32); + WarpReduceSumOneStep(8, 32); + WarpReduceSumOneStep(4, 32); + WarpReduceSumOneStep(2, 32); + WarpReduceSumOneStep(1, 32); +#undef WarpReduceSumOneStep +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 
0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 2; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 4; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = 0.f; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} + +template <> +__inline__ __device__ void blockReduce(float *pval) { + const int num = 1; + static __shared__ float shared[num][32]; + int lane_id = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduce(pval); + + if (lane_id == 0) { +#pragma unroll + for (int i = 0; i < num; ++i) { + shared[i][wid] = *(pval + i); + } + } + __syncthreads(); + + if (threadIdx.x < (blockDim.x >> 5)) { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = shared[i][lane_id]; + } + } else { +#pragma unroll + for (int i = 0; i < num; ++i) { + *(pval + i) = REDUCE_FLOAT_INF_NEG; + } + } + warpReduce(pval); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h new file mode 100644 index 000000000..f7d75f38c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include +#include + +#include "cuda_util.h" + +class Context { + public: + Context() : _stream(nullptr) { + 
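    // The cuBLAS handle is created exactly once here: Context is a
    // lazily-constructed singleton, so callers typically do something like
    // (illustrative, not part of this patch)
    //   Context::Instance().set_stream(stream);
    //   cublasHandle_t handle = Context::Instance().get_cublashandle();
    // and set_stream() keeps the handle bound to the caller's stream.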
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); + } + + virtual ~Context() {} + + static Context &Instance() { + static Context _ctx; + return _ctx; + } + + void set_stream(cudaStream_t stream) { + _stream = stream; + CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); + } + + cudaStream_t get_stream() { return _stream; } + + cublasHandle_t get_cublashandle() { return _cublasHandle; } + + private: + cudaStream_t _stream; + cublasHandle_t _cublasHandle; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h new file mode 100644 index 000000000..f4e9befc6 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +#include + +#include "cuda_util.h" + +template +class CrossEntropyLayer { + public: + CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); + + virtual ~CrossEntropyLayer(); + + void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, + float *nll_loss_ptr); + + void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, + const int *targets_ptr, T *grad_inputs_ptr); + + void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); + + private: + void allocate_mem_buffer() { + // allocate local gpu memory + _loss_buffer = cuda_malloc(_max_batch_tokens * 2); + } + + void free_mem_buffer() { + // free local gpu memory + cuda_free(_loss_buffer); + } + + const int _padding_idx; + const float _epsilon; + const int _max_batch_tokens; + + size_t _batch_size; + size_t _seq_len; + size_t _vocab_size; + + float *_loss_buffer; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h new file mode 100644 index 000000000..7ebb9ce48 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const float *A, + const float *B, float *C, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); + +int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float *alpha, const float *beta, const __half *A, + const __half *B, __half *C, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); + +int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, + const float *alpha, const float *beta, + const float *A, const float *B, float *C, + cublasOperation_t op_A, cublasOperation_t op_B, + int stride_A, int stride_B, int stride_C, + int batch, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); + +int cublas_strided_batched_gemm( + cublasHandle_t handle, int m, int n, int k, const float *alpha, + const float *beta, const __half *A, const __half *B, __half *C, + cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, + int stride_C, int batch, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h new file 
mode 100644 index 000000000..1595257be --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +void check_gpu_error(T result, char const *const func, const char *const file, + int const line); + +#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) + +template +void print_vec(const T *outv, std::string outn, int num_output_ele); + +template +T *cuda_malloc(size_t ele_num); + +void cuda_free(void *pdata); + +template +void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, + std::string file, int line, cudaStream_t stream); + +#define CHECK_NAN_INF(ptr, size, stream) \ + check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ + check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h new file mode 100644 index 000000000..336bbacc9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -0,0 +1,95 @@ +#pragma once + +#include +#include +#include +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. 
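  // Fuses bias add, the chosen activation (relu/gelu) and dropout in a single
  // kernel launch. Illustrative call only -- the buffer names below are not
  // part of this patch:
  //   dropout.bias_act_dropout(ff1_act, ff1_out, ff1_bias,
  //                            batch_size * seq_len, intermediate_dim,
  //                            "gelu", stream);
  // i.e. rows is the number of tokens and cols the activation width.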
+ void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h new file mode 100644 index 000000000..ec963259f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h @@ -0,0 +1,68 @@ +#pragma once + +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#include +#include +#include + +#include + +#include "cublas_wrappers.h" +#include "kernels.h" + +template +class FeedForward { + public: + struct Config { + int outputSize; + int inputSize; + std::array gemm_algos; + Config(int outputs, int inputs) + : outputSize(outputs), + inputSize(inputs), + gemm_algos(std::array({99, 99, 99})) {} + }; + + FeedForward(Config config) : config_(config) {} + + ~FeedForward() {} + + void Forward(int bsz, const T *input_ptr, const T *weights, T *out, + cublasHandle_t &_cublasHandle) { + float alpha = T(1.); + float beta = T(0.); + + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, + bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, + out, cublasGemmAlgo_t(config_.gemm_algos[0])); + } + void Backward(int bsz, const T *out_grad, const T *input_ptr, + const T *weights, T *weights_grad, T *bias_grad, + cublasHandle_t &_cublasHandle, cudaStream_t &stream, + T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, + bool compute_bias = true) { + float alpha = (T)1.0, beta = (T)0.0; + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, + config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, + weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); + + cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, + bsz, config_.outputSize, &alpha, &beta, weights, out_grad, + inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); + if (compute_bias) { + launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, + config_.outputSize, stream); + } + } + + void reset_size(int outputSize, int inputSize) { + config_.outputSize = outputSize; + config_.inputSize = inputSize; + } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h 
b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h new file mode 100644 index 000000000..109aca48c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -0,0 +1,274 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define MAX_THREADS 1024 +#define WARP_SIZE 32 + +enum class ActivationType { kRelu, kGelu }; + +void launch_curand_init(int total_count, int dim, cudaStream_t stream); + +template +void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int batch_size, + int hidden_dim, cudaStream_t stream); + +template +void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, const T *vars, const T *means, int batch, + int hidden_dim, cudaStream_t stream[2]); + +template +void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, + int from_len, int to_len, bool mask_future, + cudaStream_t stream); + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream); + +// [b, s, h] -> [b, nh, s, ad] +template +void launch_transform_0213(T *output, const T *vals, int batch_size, + int seq_length, int hidden_dim, int nhead, + cudaStream_t stream); + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template +void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, + int dim_0, int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream); + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template +void launch_transform4d_0213(T *output, const T *vals, int batch_size, + int seq_len, int hidden_dim, int nhead, + int trans_count, cudaStream_t stream); + +template +void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, + float ratio, cudaStream_t stream, bool backward = false); + +template +void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, + const T *bias, const T *residual, + int total_count, int dim, float ratio, + cudaStream_t stream); + +template +void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, + const T *bias, int total_count, int dim, + float ratio, cudaStream_t stream); + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template +void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, + cudaStream_t stream); + +void launch_param_update(const float *input, __half *output, int size, + cudaStream_t stream); + +template +void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, + int sz2, int sz1_1, int sz1_2, cudaStream_t stream); + +template +void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, + int seq_len, int hidden_size, cudaStream_t &stream); + +template +void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, + float *outputs_ptr, float *nll_loss_ptr, + float *loss_buffer, const int padding_idx, + const float epsilon, const int batch_size, + const int seq_len, const int vocab_size, + cudaStream_t stream); + +template +void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, + const 
int *targets_ptr, T *grad_inputs_ptr, + const int padding_idx, const float epsilon, + const int batch_size, const int seq_len, + const int vocab_size, cudaStream_t stream); + +template +void launch_lookup_scale_pos_dropout( + T *output, const int *input, const T *embeddings, const T *pos_embeddings, + uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, + int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); + +template +void launch_d_lookup_scale_pos_dropout( + T *grad_embeddings, const T *grad_output, const int *input, + const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, + int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); + +/* Convert 2-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { + return id1 * dim2 + id2; +} + +/* Convert 3-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, + int dim2, int dim3) { + return id1 * dim2 * dim3 + id2 * dim3 + id3; +} + +/* Convert 4-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, + int id4, int dim2, int dim3, + int dim4) { + // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; + int res = id4; + + int ld = dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert 5-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, + int id4, int id5, int dim2, + int dim3, int dim4, + int dim5) { + // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + + // id4*dim5 + dim5; + int res = id5; + + int ld = dim5; + res += id4 * ld; + + ld *= dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert 6-dim tensor index into vector index */ +__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, + int id4, int id5, int id6, + int dim2, int dim3, int dim4, + int dim5, int dim6) { + // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + + // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; + int res = id6; + + int ld = dim6; + res += id5 * ld; + + ld *= dim5; + res += id4 * ld; + + ld *= dim4; + res += id3 * ld; + + ld *= dim3; + res += id2 * ld; + + ld *= dim2; + res += id1 * ld; + + return res; +} + +/* Convert vector index to 6-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_6dim( + int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, + int *id1, int *id2, int *id3, int *id4, int *id5) { + *id5 = src % dim5; + src /= dim5; + + *id4 = src % dim4; + src /= dim4; + + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 5-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, + int dim2, int dim3, + int dim4, int *id0, + int *id1, int *id2, + int *id3, int *id4) { + *id4 = src % dim4; + src /= dim4; + + *id3 = src % dim3; + src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 4-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, + int dim2, int dim3, + int *id0, int *id1, + int *id2, int *id3) { + *id3 = src % dim3; + 
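  // Worked example: flat_4dim(1, 2, 3, 4, /*dim2=*/3, /*dim3=*/4, /*dim4=*/5)
  // = 119, and decompose_4dim(119, 3, 4, 5, ...) recovers (1, 2, 3, 4).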
src /= dim3; + + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 3-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { + *id2 = src % dim2; + src /= dim2; + + *id1 = src % dim1; + *id0 = src / dim1; +} + +/* Convert vector index to 2-dim tensor index */ +__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, + int *id0, int *id1) { + *id1 = src % dim1; + *id0 = src / dim1; +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh new file mode 100644 index 000000000..4f65e7b54 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh @@ -0,0 +1,12 @@ +// copied from https://github.com/dmlc/dgl/pull/2758 +#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ +#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ + +#define CUB_NS_PREFIX namespace ls { +#define CUB_NS_POSTFIX } +#include "cub/cub.cuh" +#include "cub/util_allocator.cuh" +#undef CUB_NS_POSTFIX +#undef CUB_NS_PREFIX + +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h new file mode 100644 index 000000000..22e16fe90 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Normalize_Layer { + public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. 
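  For example, when use_mean == true the caller passes the ln input as
  inp_or_out and leaves betta as nullptr (xhat is rebuilt from means/vars);
  when use_mean == false the caller passes the ln output together with betta,
  so xhat can be rebuilt as (output - betta) / gamma.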
+ residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + + private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h new file mode 100644 index 000000000..978c72fed --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { + config_.nhead = nhead; + } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h new file mode 100644 index 000000000..3120660b9 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h @@ -0,0 +1,99 @@ +/* Copyright 2021 The LightSeq Team + Copyright Microsoft DeepSpeed + This file is adapted from Microsoft DeepSpeed +*/ +#pragma once + +#include +#include +#include + +#include + +#include "cublas_wrappers.h" + +template +class StridedBatchGemm { + public: + struct Config { + int m; + int n; + int k; + float alpha; + float beta; + cublasOperation_t op_A; + cublasOperation_t op_B; + std::array gemm_algos; + + Config(float param_alpha, float param_beta, cublasOperation_t opA, + cublasOperation_t opB) + : alpha(param_alpha), + beta(param_beta), + op_A(opA), + op_B(opB), + gemm_algos(std::array({99, 99, 99})) {} + void SetConfig(int mm, int nn, int kk) { + m = mm; + n = nn; + k = kk; + } + }; + + StridedBatchGemm(const Config &config) : _config(config) {} + + virtual ~StridedBatchGemm() {} + + void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, + cublasHandle_t handle) { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + cublas_strided_batched_gemm( + handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, + _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, + stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); + } + + void Backward(int bsz, const T *d_output, const T *_buffer_a, + const T *_buffer_b, cublasHandle_t handle, + T 
*inpGradA = nullptr, T *inpGradB = nullptr) { + int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); + int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); + + int stride_a = mb * _config.n; + int stride_b = _config.n * kb; + int stride_c = _config.m * _config.k; + + // B need to transpose. + cublasOperation_t op_b = + (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. + cublas_strided_batched_gemm( + handle, mb, kb, _config.n, &_config.alpha, &_config.beta, + (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), + (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, + CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, + cublasGemmAlgo_t(_config.gemm_algos[1])); + + // A need to transpose. + cublasOperation_t op_a = + (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + stride_a = _config.m * _config.k; + stride_b = _config.m * _config.n; + stride_c = _config.n * _config.k; + + // Calculate d_B. + cublas_strided_batched_gemm( + handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, + _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, + stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); + } + + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + + private: + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu new file mode 100644 index 000000000..d992e7e14 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -0,0 +1,1160 @@ +#include "block_reduce.h" +#include "kernels.h" + +#include + +namespace cg = cooperative_groups; +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template +__forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. + may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. 
layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half *bias, +// int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. 
compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x; +// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y; +// val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half *bias, +// int hidden_size) { +// // step 0. compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; +// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y; +// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. 
compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x; +// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y; +// val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x; +// val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y; +// val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x; +// val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void 
launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. 
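Since out = xhat * gamma + betta, each element of out_grad contributes dout to
dbetta and dout * xhat to dgamma, reduced over the batch/sequence rows: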
+dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, + const T *out_grad, const T *inp_or_out, + const T *gamma, const T *betta, + const T *vars, const T *means, int rows, + int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. 
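This is the standard layer-norm input gradient, obtained by differentiating
through the per-token mean and variance and rescaling by rsqrt(var):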
+dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T *betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. 
xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. 
dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; + dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + dxhat_2[i].x + + dxhat_2[i].y + dxhat_3[i].x + dxhat_3[i].y; + } + + /* + step 1. 
xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. 
block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset+2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset+3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. 
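+
+The two gradients are issued on the two streams passed in (stream[2]), so the
+gamma/betta reduction and the input-gradient kernel can overlap. A host-side
+call could look like this sketch (buffer names are illustrative, pre-LN case
+with saved mean and variance):
+  launch_ln_bw<float>(dgamma, dbetta, dinp, dout, /*residual_grad=*/nullptr,
+                      ln_input, gamma, /*betta=*/nullptr, vars, means,
+                      batch_tokens, hidden_dim, streams);
+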
+dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. +residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} + diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu new file mode 100644 index 000000000..86579201b --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -0,0 +1,366 @@ +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +#include + +namespace cg = cooperative_groups; +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax 
forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. 
compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. 
compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. 
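+
+Per row, the kernel computes
+  grad[i] = output[i] * (grad[i] - sum_j(grad[j] * output[j]))
+where the row sum comes from a warp-level shfl_xor reduction.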
+*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu new file mode 100644 index 000000000..d389d57e1 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -0,0 +1,314 @@ +#include +#include +#include + +#include "kernels.h" + +using namespace cub; + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void 
transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void bias_add_transform_20314(float *output, + const float *input, + const float *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + 
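+  // vectorized view: each float4 covers four contiguous fp32 values of the
+  // innermost dimension (the launcher already divided dim_4 by 4)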
float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + 
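+  // recover the [tc, b, nh, s, ad] source indices from the flat element offset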
decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp new file mode 100644 index 000000000..c42d91d36 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -0,0 +1,185 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include +#include +#include +#include "compat.h" + +namespace { + +void compute_n1_n2( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2) { + int idiff = input.ndimension() - normalized_shape.size(); + n2 = 1; + for (int i = 0; i < (int)normalized_shape.size(); ++i) { + assert( input.sizes()[i+idiff] == normalized_shape[i] ); + n2 *= normalized_shape[i]; + } + n1 = 1; + for (int i = 0; i < idiff; ++i) { + n1 *= input.sizes()[i]; + } +} + +void check_args( + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta + ) +{ + TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); + TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); +} + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + int& n1, + int& n2 + ) +{ + int64_t normalized_ndim = normalized_shape.size(); + + if (normalized_ndim < 1) { + std::stringstream ss; + ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " + << "containing at least one element, but got normalized_shape=" + << normalized_shape; + throw std::runtime_error(ss.str()); + } + + auto input_shape = input.sizes(); + auto input_ndim = input.dim(); + + if (input_ndim < normalized_ndim || + !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { + std::stringstream ss; + ss << "Given normalized_shape=" << normalized_shape + << ", expected input with shape [*"; + for (auto size : normalized_shape) { + ss << ", " << size; + } + ss << "], but got input of size" << input_shape; + throw std::runtime_error(ss.str()); + } + + compute_n1_n2(input,normalized_shape,n1,n2); +} + + +void check_args( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + int& n1, + int& n2 + ) 
+{ + check_args(input,normalized_shape,n1,n2); + check_args(normalized_shape,gamma,beta); +} +} + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon); + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector layer_norm_affine( + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor output = at::empty_like( + input, gamma.options().dtype(gamma.scalar_type())); + at::Tensor mean = at::empty( + {n1}, input.options().dtype(at::ScalarType::Float)); + at::Tensor invvar = at::empty_like(mean); + + cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon); + + return {output, mean, invvar}; + +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + at::IntArrayRef normalized_shape, + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta + ); + +std::vector layer_norm_gradient_affine( + at::Tensor dout, + at::Tensor mean, + at::Tensor invvar, + at::Tensor input, + at::IntArrayRef normalized_shape, + at::Tensor gamma, + at::Tensor beta, + double epsilon) { + + CHECK_INPUT(dout); + CHECK_INPUT(mean); + CHECK_INPUT(invvar); + CHECK_INPUT(input); + CHECK_INPUT(gamma); + CHECK_INPUT(beta); + int n1, n2; + check_args(input, normalized_shape, gamma, beta, n1, n2); + + at::Tensor grad_input = at::empty_like(input); + at::Tensor grad_gamma = at::empty_like(gamma); + at::Tensor grad_beta = at::empty_like(beta); + + cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, + normalized_shape, &gamma, &beta, epsilon, + &grad_input, &grad_gamma, &grad_beta); + + return {grad_input, grad_gamma, grad_beta}; + +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_affine", &layer_norm_affine, + "LayerNorm forward (CUDA)"); + m.def("backward_affine", &layer_norm_gradient_affine, + "LayerNorm backward (CUDA)"); +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu new file mode 100644 index 000000000..dc52f8019 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -0,0 +1,813 @@ +/*This code from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. 
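+ * This file provides the fused forward kernel (cuApplyLayerNorm) and the
+ * backward kernels (cuComputePartGradGammaBeta, cuComputeGradGammaBeta,
+ * cuComputeGradInput) behind the bindings in layer_norm_cuda.cpp.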
*/ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/cuda/CUDAContext.h" +#include + +#include +#include + +#include "type_shim.h" + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2*threadIdx.y]; + U sigma2B = ubuf[2*threadIdx.y+1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/U(n2), 0); + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int n2, + const int i1, + float& mu, + float& sigma2, + float* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
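+  // The fp16 path upcasts pairs of values through __half2 loads (after
+  // consuming a possibly unaligned leading element) and feeds them into the
+  // same Welford update as the generic template.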
+ // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2*threadIdx.y]; + float sigma2B = ubuf[2*threadIdx.y+1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/float(n2), 0); + } + } +} + +template U rsqrt(U v) { + return U(1) / sqrt(v); +} +template<> float rsqrt(float v) { + return rsqrtf(v); +} +template<> double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. 
+// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + extern __shared__ float s_float[]; + return s_float; + } +}; + +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + } +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template __global__ 
+void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int 
warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input) +{ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U 
term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + } +} + + + + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const dim3 threads(32,4,1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + cuApplyLayerNorm<<>>( + output, + mean, + invvar, + input, + n1,n2, + U(epsilon), + gamma,beta); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? beta->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * + (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? 
nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size,n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR()); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? + threads1.y*threads1.x*sizeof(U) : + 0; + cuComputeGradInput<<>>( + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input); +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? 
grad_beta->DATA_PTR() : NULL); + ) +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp new file mode 100644 index 000000000..63bf633f5 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp @@ -0,0 +1,364 @@ +#include "multihead_attention_1d.h" + +#include +#include + +#include +#include + +#include "context.h" +#include "kernels.h" + +template +MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len, + int hidden_size, int num_heads, + float attn_prob_dropout_ratio, + float hidden_output_dropout_ratio, + bool pre_or_postLayerNorm) + : _layer_id(layer_id), + _max_batch_tokens(max_batch_tokens), + _max_seq_len(max_seq_len), + _hidden_size(hidden_size), + _heads(num_heads), + _training(true), + _pre_or_postLayerNorm(pre_or_postLayerNorm), + _qkv_linear(typename FeedForward::Config(3 * hidden_size, hidden_size)), + _attn_out_linear(typename FeedForward::Config(hidden_size, hidden_size)), + _attn_ln(typename Normalize_Layer::Config(hidden_size, false), _max_batch_tokens), + _softmax(typename Softmax::Config(num_heads)), + _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), + _max_batch_tokens * _heads * _max_seq_len), + _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), + _max_batch_tokens * _hidden_size), + _attn_scores(typename StridedBatchGemm::Config((T(1.0) / T(sqrt(_hidden_size / _heads))), + T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)), + _attn_context( + typename StridedBatchGemm::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { + assert(_hidden_size % _heads == 0); +} + +template +MultiHeadAttention::~MultiHeadAttention() { + free_mem_buffer(); +} + +template +void MultiHeadAttention::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, + T *output_ptr, T *buffer) { + T *q_tf_ptr = _qkv_ptr; + T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; + T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; + + if (_pre_or_postLayerNorm) { + _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, + _stream); + } + const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); + _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle); + + launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3, + _heads / pg_size, _hidden_size / _heads, _stream); + + // attention scores, q*k + _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle); + + // Softmax + Mask + _softmax.reset_size(_heads / pg_size); + _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true); + + // attn prob dropout. 
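+  // This samples a dropout mask over the softmax probabilities in
+  // _soft_out_ptr and writes the result to _ctx_bufB_ptr, which the
+  // score * v batched GEMM below consumes; _soft_out_ptr is left intact
+  // because the softmax backward pass reads it again.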
+ _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len, + _stream); + + // attention context, score * v + _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle); + + // [b, nh, s, ad] -> [b, s, nh, ad] + launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size, + _heads / pg_size, 1, _stream); + + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); + _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle); + + // allreduce + if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { + } else { + auto data_type = torch::kFloat; + if (typeid(T) != typeid(float)) { + data_type = torch::kHalf; + } + auto output_tensor = + torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); + std::vector allreduce_tensors = {output_tensor}; + auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); + work->wait(); + } + + _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr, + _batch_tokens, _hidden_size, _stream); + if (!_pre_or_postLayerNorm) { + // in-place ln since ln-input will not be used in post-ln mode + _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream); + } +} + +template +void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr) { + _stream = Context::Instance().get_stream(); + _cublasHandle = Context::Instance().get_cublashandle(); + T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim + + attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); +} + +template +void MultiHeadAttention::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr, + const T *grad_output_ptr, T *grad_input_ptr, T *buffer) { + cudaStream_t streams[2] = {_stream, _stream}; + + const T *q_tf_ptr = _qkv_ptr; + const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; + const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; + // batch_dim = batch_size * seq_len * hidden_size + // buffer size: batch_dim * 3 + max(batch_dim * 3, + // batch_size * head_num * seq_len * seq_len) + T *grad_residual_ptr = buffer; + buffer += _batch_dim; + + T *grad_input_buf_ptr = buffer; // batch_dim + T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 + buffer += 3 * _batch_dim / pg_size; + + T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 + T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len + // buffer += max(3 * _batch_dim, + // batch_size * head_num * seq_len * seq_len); + + if (_pre_or_postLayerNorm) { + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr, + _batch_tokens, _hidden_size, _stream); + } else { + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr, + nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); + _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr, + _batch_tokens, _hidden_size, _stream); + } + + // bw of output project + _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); + _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr, + _grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream, + grad_input_buf_ptr, nullptr, false); + launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len, + _hidden_size / pg_size, 
_heads / pg_size, _stream); + + // bw of score * v + _attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, + grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); + + _attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream); + + _softmax.reset_size(_heads / pg_size); + _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream); + + // bw of q * k + _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle, + grad_qkv_5d_ptr + _batch_dim / pg_size, grad_qkv_5d_ptr); + + // [3, b, nh, s, ad] -> [b, s, 3, h] + launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len, + _hidden_size / pg_size, _heads / pg_size, 3, _stream); + + const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; + _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); + _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr, + _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream, + grad_input_buf_ptr, nullptr, true); + + // allreduce + if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { + } else { + auto data_type = torch::kFloat; + if (typeid(T) != typeid(float)) { + data_type = torch::kHalf; + } + auto grad_input_tensor = + torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, + torch::TensorOptions(torch::kCUDA).dtype(data_type)); + std::vector allreduce_tensors = {grad_input_tensor}; + auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); + work->wait(); + } + + if (_pre_or_postLayerNorm) { + _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr, + grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, + streams); + } else { + // FIXME later + launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size, + _seq_len, _hidden_size, _stream); + } +} + +template +void MultiHeadAttention::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr, + const T *input_mask_ptr, T *grad_input_ptr) { + _stream = Context::Instance().get_stream(); + _cublasHandle = Context::Instance().get_cublashandle(); + T *buffer = _shared_mem_ptr; + + /* + buffer size needed by attn bw: + 4 * _batch_dim + max(3 * _batch_dim, + _batch_size * _head_num * _seq_len * _seq_len); + */ + attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer); +} + +template +void MultiHeadAttention::SetTrainingMode(bool training) { + // Dropout will be skipped when not in training model. 
+ _attn_prob_dropout.SetTrainingMode(training); + _attn_dropout.SetTrainingMode(training); +} + +template +T *MultiHeadAttention::_shared_mem_ptr = nullptr; + +template class MultiHeadAttention; +template class MultiHeadAttention<__half>; + +// x is torch::Tensor +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +static std::unordered_map> s_multihead_attention; + +template +int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim, + int num_heads, float attn_prob_dropout_ratio, + float hidden_dropout_ratio, bool pre_or_postLayerNorm, + c10::intrusive_ptr pg_) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + Context::Instance().set_stream(stream); + auto layer = std::make_shared>( + layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio, + hidden_dropout_ratio, pre_or_postLayerNorm); + + layer->SetPG(pg_); + + s_multihead_attention[layer_id] = layer; + + std::string dtype = (std::is_same::value) ? "half" : "float"; + + return 0; +} + +template +std::vector multihead_attention_fw(int layer_id, const torch::Tensor &input, + const torch::Tensor &input_mask, + const torch::Tensor &in_proj_weight, + const torch::Tensor &in_proj_bias, + const torch::Tensor &out_proj_weight, + const torch::Tensor &out_proj_bias, + const torch::Tensor &norm_weight, + const torch::Tensor &norm_bias, + bool training_mode, bool prelayernorm) { + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + + const T *input_ptr = (const T *)input.data_ptr(); + const T *input_mask_ptr = (const T *)input_mask.data_ptr(); + + auto output = torch::empty_like(input); + T *out_ptr = (T *)output.data_ptr(); + + std::shared_ptr> layer = + std::static_pointer_cast>(s_multihead_attention[layer_id]); + layer->set_cur_batch_shape(input.size(0), input.size(1)); + layer->SetTrainingMode(training_mode); + + layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); + layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); + layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); + layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); + layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); + layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); + + layer->Forward(input_ptr, input_mask_ptr, out_ptr); + + return {output}; +} + +template +std::vector multihead_attention_bw(int layer_id, + const torch::Tensor &grad_dec_output, + const torch::Tensor &output, + const torch::Tensor &input, + const torch::Tensor &input_mask, + const torch::Tensor &in_proj_weight, + const torch::Tensor &in_proj_bias, + const torch::Tensor &out_proj_weight, + const torch::Tensor &out_proj_bias, + const torch::Tensor &norm_weight, + const torch::Tensor &norm_bias) { + auto g_output = grad_dec_output.contiguous(); + CHECK_INPUT(g_output); + CHECK_INPUT(output); + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + + auto grad_input = torch::empty_like(input); + auto grad_in_proj_weight = torch::empty_like(in_proj_weight); + auto grad_in_proj_bias = torch::empty_like(in_proj_bias); + auto grad_out_proj_weight = torch::empty_like(out_proj_weight); + auto grad_out_proj_bias = torch::empty_like(out_proj_bias); + auto grad_norm_weight = torch::empty_like(norm_weight); + auto grad_norm_bias = torch::empty_like(norm_bias); + + // inputs. 
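+  // The layer below only sees raw device pointers: grad_dec_output was made
+  // contiguous and the remaining inputs were checked with CHECK_INPUT, so the
+  // flat indexing those pointers assume stays valid.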
+ const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); + const T *input_ptr = (const T *)input.data_ptr(); + const T *output_ptr = (const T *)output.data_ptr(); + const T *input_mask_ptr = (const T *)input_mask.data_ptr(); + + // outputs. + T *grad_input_ptr = (T *)grad_input.data_ptr(); + + std::shared_ptr> layer = + std::static_pointer_cast>(s_multihead_attention[layer_id]); + layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); + + layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); + layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); + layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); + layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); + layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); + layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); + + layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr); + + return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, + grad_out_proj_bias, grad_norm_weight, grad_norm_bias}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("multihead_attention_fw_fp32", &multihead_attention_fw, + "Multi-head Attention forward with fp32 (CUDA)"); + m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, + "Multi-head Attention forward with fp16 (CUDA)"); + m.def("multihead_attention_bw_fp32", &multihead_attention_bw, + "Multi-head Attention backward with fp32 (CUDA)"); + m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, + "Multi-head Attention backward with fp16 (CUDA)"); + m.def("create_multihead_attention_fp32", &create_multihead_attention, + "Create Multi-head Attention with fp32 (CUDA)"); + m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, + "Create Multi-head Attention with fp16 (CUDA)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h new file mode 100644 index 000000000..1dd84773a --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h @@ -0,0 +1,153 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include "cuda_util.h" +#include "dropout.h" +#include "feed_forward.h" +#include "normalize_layer.h" +#include "softmax.h" +#include "strided_batch_gemm.h" + +template +class MultiHeadAttention { + public: + MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size, + int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio, + bool pre_or_postLayerNorm); + + virtual ~MultiHeadAttention(); + + void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); + + void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr, + const T *input_mask_ptr, T *grad_input_ptr); + + void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer); + + void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr, + const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer); + + void set_cur_batch_shape(int batch_size, int seq_len) { + _batch_size = batch_size; + _seq_len = seq_len; + _batch_tokens = batch_size * seq_len; + _batch_heads = batch_size * _heads / pg_size; + _batch_dim = _batch_tokens * _hidden_size; + _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); + _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); + } + + void 
SetTrainingMode(bool training); + inline bool IsTrainingMode() const { return _training; } + + void SetPG(c10::intrusive_ptr pg_) { + pg = pg_; + pg_size = 1; + if (pg != c10::detail::UniqueVoidPtr()) { + pg_size = pg->getSize(); + } + allocate_mem_buffer(); + } + + // weights ptr + const T *_attn_qkvw_ptr; + const T *_attn_qkvb_ptr; + const T *_attn_ow_ptr; + const T *_attn_ob_ptr; + const T *_attn_nw_ptr; + const T *_attn_nb_ptr; + + // grads ptr + T *_grad_attn_qkvw_ptr; + T *_grad_attn_qkvb_ptr; + T *_grad_attn_ow_ptr; + T *_grad_attn_ob_ptr; + T *_grad_attn_nw_ptr; + T *_grad_attn_nb_ptr; + + private: + void allocate_mem_buffer() { + // allocate local gpu memory + if (_pre_or_postLayerNorm) { + _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); + } else { + _gemmQKV_inp_ptr = nullptr; + } + + _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); + _soft_out_ptr = cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _ctx_bufB_ptr = cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); + _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); + + // buffer size needed by attn bw + size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size + + std::max(3 * _max_batch_tokens * _hidden_size / pg_size, + _max_batch_tokens * _heads / pg_size * _max_seq_len); + + if (!_shared_mem_ptr) { + cuda_free(_shared_mem_ptr); + _shared_mem_ptr = cuda_malloc(smem_size); + } + } + + void free_mem_buffer() { + // free local gpu memory + cuda_free(_gemmQKV_inp_ptr); + cuda_free(_qkv_ptr); + cuda_free(_soft_out_ptr); + cuda_free(_ctx_bufB_ptr); + cuda_free(_attn_o_inp_ptr); + + // free shared gpu memory between layers + cuda_free(_shared_mem_ptr); + _shared_mem_ptr = nullptr; + } + + // const parameter between batch + const size_t _layer_id; + const size_t _hidden_size; + const size_t _heads; + const size_t _max_batch_tokens; + const size_t _max_seq_len; + const bool _pre_or_postLayerNorm; + // dynamic parameter between batch + size_t _batch_size; + size_t _seq_len; + size_t _batch_tokens; + size_t _batch_heads; + size_t _batch_dim; + bool _training; + + cublasHandle_t _cublasHandle; + cudaStream_t _stream; + + // layers + FeedForward _qkv_linear; + FeedForward _attn_out_linear; + Normalize_Layer _attn_ln; + Softmax _softmax; + Dropout _attn_prob_dropout; + Dropout _attn_dropout; + StridedBatchGemm _attn_scores; + StridedBatchGemm _attn_context; + + // local GPU memory + T *_gemmQKV_inp_ptr; + T *_qkv_ptr; + T *_soft_out_ptr; + T *_ctx_bufB_ptr; + T *_attn_o_inp_ptr; + // shared GPU memory between layer + static T *_shared_mem_ptr; + + c10::intrusive_ptr pg; + int pg_size; +}; \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp new file mode 100644 index 000000000..4ae3c853c --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -0,0 +1,84 @@ +/*This code from NVIDIA Megatron: + * with minor changes. 
*/ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h new file mode 100644 index 000000000..1583030b8 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -0,0 +1,492 @@ +/*This code from NVIDIA Megatron: + * with minor changes. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. 
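+  // local_batches is clamped to WARP_BATCH just below, so the last warp of a
+  // micro batch simply processes fewer rows (the per-row loops treat rows with
+  // i >= local_batches as empty).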
+ int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = 
log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu new file mode 100644 index 000000000..d2370e9f3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu @@ -0,0 +1,104 @@ +/*This code from NVIDIA Megatron: + * with minor changes. 
*/ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 000000000..590ea7b3f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,59 @@ +/*This code from NVIDIA Megatron: + * with minor changes. 
*/ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h new file mode 100644 index 000000000..3af487f9d --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,500 @@ +/*This code from NVIDIA Megatron: + * with minor changes. 
*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. 
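+    // Rows shorter than a full hardware warp get a narrower logical warp so that no lanes sit idle.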
+ int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. 
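+    // Mirrors the forward dispatch: a warp processes two rows only when each row fits in 128 elements.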
+ int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 000000000..9dbb63476 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,85 @@ +/*This code from NVIDIA Megatron: + * with minor changes. 
*/ + +#include +#include +#include +#include +#include +#include +#include +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + return softmax_results; +} + + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h new file mode 100644 index 000000000..845615feb --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -0,0 +1,73 @@ +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py new file mode 100644 index 000000000..22b6efa01 --- /dev/null +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -0,0 +1,69 @@ +"""This code is from NVIDIA apex: + https://github.com/NVIDIA/apex + with some changes. """ + +import numbers +import torch +from torch.nn.parameter import Parameter +from torch.nn import init +import importlib + +global colossal_layer_norm_cuda +colossal_layer_norm_cuda = None + + +class FusedLayerNormAffineFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, weight, bias, normalized_shape, eps): + + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input_ = input.contiguous() + weight_ = weight.contiguous() + bias_ = bias.contiguous() + output, mean, invvar = colossal_layer_norm_cuda.forward_affine( + input_, ctx.normalized_shape, weight_, bias_, ctx.eps) + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + + return output + + @staticmethod + def backward(ctx, grad_output): + + input_, weight_, bias_, mean, invvar = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input, grad_weight, grad_bias \ + = colossal_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) + + return grad_input, grad_weight, grad_bias, None, None + + +class MixedFusedLayerNorm(torch.nn.Module): + + def __init__(self, normalized_shape, eps=1e-5): + super(MixedFusedLayerNorm, self).__init__() + + global colossal_layer_norm_cuda + colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda") + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.reset_parameters() + + def reset_parameters(self): + + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input): + + return FusedLayerNormAffineFunction.apply(input, self.weight, self.bias, + self.normalized_shape, self.eps) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py new file mode 100644 index 000000000..52e1c0bcf --- /dev/null +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -0,0 +1,270 @@ +import math +import importlib +from dataclasses import dataclass + +import torch +from torch import nn +from torch.autograd import Function + + 
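Before the multi-head attention wrapper below, a quick aside on the fused layer norm defined in layer_norm.py above: it is intended as a drop-in for torch.nn.LayerNorm, with forward and backward dispatched to the colossal_layer_norm_cuda extension. The following is only a rough usage sketch, not part of the patch; it assumes the extension has been built and a CUDA device is available, and the shapes and tolerance are illustrative.

import torch
from colossalai.kernel.cuda_native.layer_norm import MixedFusedLayerNorm

hidden = 1024
x = torch.randn(8, 512, hidden, device="cuda", requires_grad=True)

# Constructing the module imports the colossal_layer_norm_cuda extension.
fused_ln = MixedFusedLayerNorm(hidden).cuda()
out = fused_ln(x)

# With freshly initialized affine parameters (ones/zeros) the fused kernel
# should closely track the reference PyTorch layer norm.
ref = torch.nn.functional.layer_norm(x, (hidden,), fused_ln.weight, fused_ln.bias, fused_ln.eps)
print(torch.allclose(out, ref, atol=1e-4))

# Backward goes through colossal_layer_norm_cuda.backward_affine.
out.mean().backward()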
+def check_config(config): + if config.hidden_size % config.nhead != 0: + raise Exception(f"hidden_size % nhead != 0") + + factor = 8 if config.fp16 else 4 + upbound = factor * 1024 * 4 + if config.hidden_size > upbound: + # as required by ln backward kernel currently + raise Exception(f"hidden_size > {upbound}") + + head_dim = config.hidden_size // config.nhead + if head_dim % factor != 0: + # as required by reshape kernel + raise Exception(f"head_dim({head_dim}) % {factor} != 0") + + +def calc_offset(sizes): + offsets = [0] + tmp = 0 + for x in sizes: + tmp += x + offsets.append(tmp) + return offsets + + +colossal_multihead_attention = None + +@dataclass +class Config: + max_batch_tokens: int # max batch token numbers + max_seq_len: int # max sequence length + hidden_size: int # size of transformer hidden layers + nhead: int # number of heads in attention + attn_prob_dropout_ratio: float # attention score dropout ratio + hidden_dropout_ratio: float # dropout ration before residual + norm_first: bool # norm_first + fp16: bool # fp16 presion + + +class MultiHeadAttention1DFunc(Function): + + @staticmethod + def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, + out_proj_bias, norm_weight, norm_bias, config): + cuda_module = colossal_multihead_attention + forward_func = (cuda_module.multihead_attention_fw_fp16 + if config.fp16 else cuda_module.multihead_attention_fw_fp32) + if config.fp16: + input = input.to(torch.half) + input_mask = input_mask.to(torch.half) + + (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, + out_proj_weight, out_proj_bias, norm_weight, norm_bias, + config.training, config.norm_first) + + if config.is_grad_enabled and config.training: + ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, + out_proj_weight, out_proj_bias, norm_weight, norm_bias) + ctx.config = config + return output + + @staticmethod + def backward(ctx, grad_output): + assert ctx.config.training + + cuda_module = colossal_multihead_attention + backward_func = (cuda_module.multihead_attention_bw_fp16 + if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32) + + output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, \ + out_proj_bias, norm_weight, norm_bias = ctx.saved_tensors + + grad_input = None + grad_in_proj_weight = None + grad_in_proj_bias = None + grad_out_proj_weight = None + grad_out_proj_bias = None + grad_norm_weight = None + grad_norm_bias = None + + if ctx.config.fp16: + grad_output = grad_output.to(torch.half) + output = output.to(torch.half) + input = input.to(torch.half) + input_mask = input_mask.to(torch.half) + grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ + grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( + ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, \ + in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) + + return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, + grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None) + + +class MultiHeadAttention(nn.Module): + """Initialize the MultiHeadAttention. + + Static variable: + layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, + e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. + Arguments: + hidden_size: Total dimension of hidden_size. + nhead: Number of parallel attention heads. 
+ batch_size: Batch Size for one foward + max_seq_len: Max length of input sequence + dropout: Dropout probability + norm_first: perform LayerNorms before attention + """ + + layer_id = 0 + + def __init__(self, + hidden_size, + nhead, + batch_size, + max_seq_len, + dropout=0.0, + norm_first=False, + fp16=True, + pg=None): + super(MultiHeadAttention, self).__init__() + + self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, + dropout, norm_first, fp16) + check_config(self.config) + self.pg = pg + self.pg_size = 1 + if self.pg: + self.pg_size = pg.size() + self.config.layer_id = MultiHeadAttention.layer_id + MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 + + # Load cuda modules if needed + global colossal_multihead_attention + if colossal_multihead_attention is None: + colossal_multihead_attention = importlib.import_module("colossal_multihead_attention") + + # create the layer in cuda kernels. + cuda_module = colossal_multihead_attention + create_layer_func = (cuda_module.create_multihead_attention_fp16 + if self.config.fp16 else cuda_module.create_multihead_attention_fp32) + + create_layer_func( + self.config.layer_id, + self.config.max_batch_tokens, + self.config.max_seq_len, + self.config.hidden_size, + self.config.nhead, + self.config.attn_prob_dropout_ratio, + self.config.hidden_dropout_ratio, + self.config.norm_first, + self.pg, + ) + + hs = self.config.hidden_size + + self.precision = torch.float32 + if self.config.fp16: + self.precision = torch.half + + self.hs_per_rank = int(hs / self.pg_size) + + self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) + self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) + self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) + self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) + self.norm_weight = nn.Parameter(torch.Tensor(hs)) + self.norm_bias = nn.Parameter(torch.Tensor(hs)) + + self.reset_parameters() + torch.cuda.empty_cache() + + def calc_bound(self, w): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) + bound = 1.0 / math.sqrt(fan_in) + return bound + + def reset_parameters(self): + hs = self.config.hidden_size + + nn.init.zeros_(self.out_proj_bias) + + nn.init.ones_(self.norm_weight) + nn.init.zeros_(self.norm_bias) + + if self.pg_size > 1: + rank_in_pg = torch.distributed.get_rank(self.pg) + attn_qkvw_global = torch.empty(hs * 3, hs) + attn_qkvb_global = torch.empty(hs * 3) + nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) + bound = self.calc_bound(attn_qkvw_global) + nn.init.uniform_(attn_qkvb_global, -bound, bound) + + attn_qkvw_global = attn_qkvw_global.cuda() + attn_qkvb_global = attn_qkvb_global.cuda() + torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) + torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) + attn_qkvw_global = attn_qkvw_global.cpu() + attn_qkvb_global = attn_qkvb_global.cpu() + + with torch.no_grad(): + self.in_proj_weight.copy_( + attn_qkvw_global.view(3, hs, hs)[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size), :]) + self.in_proj_bias.copy_( + attn_qkvb_global.view(3, hs)[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) + + attn_ow_global = torch.empty(hs, hs) + nn.init.xavier_uniform_(attn_ow_global, 1.0) + attn_ow_global = attn_ow_global.cuda() + torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) + attn_ow_global = attn_ow_global.cpu() + with torch.no_grad(): + 
self.out_proj_weight.copy_(attn_ow_global[:, + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) + + else: + attn_qkvw = self.in_proj_weight.view(-1, hs) + nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) + bound = self.calc_bound(attn_qkvw) + nn.init.uniform_(self.in_proj_bias, -bound, bound) + + nn.init.xavier_uniform_(self.out_proj_weight, 1.0) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + destination = torch.nn.Module.state_dict(self, + destination=destination, + prefix=prefix, + keep_vars=keep_vars) + return destination + + def forward(self, hidden_states, encoder_padding_mask): + self.config.training = self.training + self.config.is_grad_enabled = torch.is_grad_enabled() + hidden_states = hidden_states.contiguous() + encoder_padding_mask = ((encoder_padding_mask * -1e8).type_as(hidden_states).contiguous()) + + bs, sl, dim = hidden_states.size() + if bs * sl > self.config.max_batch_tokens: + raise ValueError( + f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") + if sl > self.config.max_seq_len: + raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") + if len(encoder_padding_mask.size()) == 1: + assert bs == 1 and sl == encoder_padding_mask.size(0) + else: + assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) + + output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, + self.in_proj_weight, self.in_proj_bias, + self.out_proj_weight, self.out_proj_bias, + self.norm_weight, self.norm_bias, self.config) + + return output.to(self.precision) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py new file mode 100644 index 000000000..c1388e299 --- /dev/null +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -0,0 +1,184 @@ +"""This code from NVIDIA Megatron + with some changes. """ + +import torch +import torch.nn as nn +import enum + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + import colossal_scaled_upper_triang_masked_softmax + + scale_t = torch.tensor([scale]) + softmax_results = colossal_scaled_upper_triang_masked_softmax.forward( + inputs, scale_t[0] + ) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import colossal_scaled_upper_triang_masked_softmax + + softmax_results, scale_t = ctx.saved_tensors + input_grads = colossal_scaled_upper_triang_masked_softmax.backward( + output_grads, softmax_results, scale_t[0] + ) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. 
+ """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + import colossal_scaled_masked_softmax + + scale_t = torch.tensor([scale]) + + softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + import colossal_scaled_masked_softmax + + softmax_results, scale_t = ctx.saved_tensors + + input_grads = colossal_scaled_masked_softmax.backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + assert ( + self.scale is None or softmax_in_fp32 + ), "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type == AttnMaskType.causal: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if 
mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + import colossal_scaled_masked_softmax + + return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py new file mode 100644 index 000000000..3f0e888bf --- /dev/null +++ b/colossalai/kernel/jit/__init__.py @@ -0,0 +1,3 @@ +from .option import _set_jit_fusion_options + +_set_jit_fusion_options() \ No newline at end of file diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py new file mode 100644 index 000000000..3687dde79 --- /dev/null +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -0,0 +1,24 @@ +import torch + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py new file mode 100644 index 000000000..f7a425dd5 --- /dev/null +++ b/colossalai/kernel/jit/bias_gelu.py @@ -0,0 +1,41 @@ +import torch + + +###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) + +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, bias, y): + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff*g + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(bias, input) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, bias, input) + return tmp, tmp + +bias_gelu_impl = GeLUFunction.apply \ No newline at end of file diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py new file mode 100644 index 000000000..06823ad3e --- /dev/null +++ b/colossalai/kernel/jit/option.py @@ -0,0 +1,28 @@ +import torch + +JIT_OPTIONS_SET = False + +def _set_jit_fusion_options(): + """Set PyTorch JIT layer fusion options.""" + global JIT_OPTIONS_SET + if JIT_OPTIONS_SET == False: + # flags required to enable jit fusion kernels + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): + # nvfuser + torch._C._jit_set_profiling_executor(True) + torch._C._jit_set_profiling_mode(True) + torch._C._jit_override_can_fuse_on_cpu(False) + torch._C._jit_override_can_fuse_on_gpu(False) + torch._C._jit_set_texpr_fuser_enabled(False) + torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) + else: + # legacy pytorch fuser + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + + JIT_OPTIONS_SET = True diff --git a/setup.py b/setup.py index f7684d4da..20ddf3477 100644 --- a/setup.py +++ b/setup.py @@ -131,5 +131,6 @@ setup( description='An integrated large-scale model training system with efficient parallelization techniques', ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension} if ext_modules else {}, + package_data={'colossalai': ['kernel/cuda_native/csrc/*']}, install_requires=install_requires, ) \ No newline at end of file
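
To close, a minimal sketch (not part of the patch) of driving the fused causal softmax above through its autograd wrapper in scaled_softmax.py. It assumes the colossal_scaled_upper_triang_masked_softmax extension has been built and a GPU is present; the shapes, the scale of 1.0, and the tolerance are arbitrary, and the unfused PyTorch computation is only a reference for comparison. In practice this wrapper is normally reached through FusedScaleMaskSoftmax, which falls back to the plain PyTorch path when the kernel's shape constraints are not met.

import torch
from colossalai.kernel.cuda_native.scaled_softmax import ScaledUpperTriangMaskedSoftmax

# Kernel layout: a 3D tensor of attention scores (attn_batches, sq, sk) with sq == sk.
attn_batches, seq = 16, 128
scores = torch.randn(attn_batches, seq, seq, device="cuda", dtype=torch.half)

probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 1.0)

# Reference: scale (here 1.0), mask out the strict upper triangle, then softmax.
causal_mask = torch.triu(torch.ones(seq, seq, device="cuda", dtype=torch.bool), diagonal=1)
ref = torch.softmax(scores.float().masked_fill(causal_mask, float("-inf")), dim=-1).half()
print(torch.allclose(probs, ref, atol=1e-3))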