diff --git a/extensions/csrc/__init__.py b/extensions/csrc/__init__.py
index 0eac28d23..e69de29bb 100644
--- a/extensions/csrc/__init__.py
+++ b/extensions/csrc/__init__.py
@@ -1,11 +0,0 @@
-from .layer_norm import MixedFusedLayerNorm as LayerNorm
-from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
-
-__all__ = [
-    "LayerNorm",
-    "MultiHeadAttention",
-    "FusedScaleMaskSoftmax",
-    "ScaledUpperTriangMaskedSoftmax",
-    "AttnMaskType",
-]
diff --git a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
index b45daea47..f992e6faa 100644
--- a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu
@@ -4,6 +4,10 @@
 #include "utils/vec_copy.h"
 #include "../common/micros.h"
 
+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
+
 template
 __global__ void context_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
index e0cfbbed7..8eb9fb00f 100644
--- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -4,6 +4,9 @@
 #include "utils/vec_copy.h"
 #include "../common/micros.h"
 
+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
 template
 __global__ void decode_kv_cache_memcpy_kernel(
     const scalar_t* __restrict__ key,
diff --git a/extensions/csrc/cuda/funcs/cast_functor.h b/extensions/csrc/cuda/funcs/cast_functor.h
index dbb7195d0..05fffb766 100644
--- a/extensions/csrc/cuda/funcs/cast_functor.h
+++ b/extensions/csrc/cuda/funcs/cast_functor.h
@@ -30,17 +30,25 @@ struct CastFunctor : public std::unary_function {
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
                                      DEVICE)
+
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
                                      DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half, __float2half(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16,
+                                     __float2bfloat16(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat162,
+                                     __float2bfloat162_rn(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
+                                     DEVICE)
+
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float,
+                                     __bfloat162float(val), DEVICE)
+
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
                                      DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
-                                     DEVICE)
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
-COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162,
-                                     __float2bfloat162_rn(val), DEVICE)
 
 #undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
 
 }  // namespace funcs
diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/funcs/reduce_function.h
similarity index 65%
rename from extensions/csrc/cuda/include/block_reduce.h
rename to extensions/csrc/cuda/funcs/reduce_function.h
index a9bd537f7..da2743e62 100644
--- a/extensions/csrc/cuda/include/block_reduce.h
+++ b/extensions/csrc/cuda/funcs/reduce_function.h
@@ -8,7 +8,7 @@
 namespace colossalAI {
 namespace cuda {
-namespace utils {
+namespace funcs {
 
 const float kReduceFloatInfNeg = -100000000.f;
 const float kReduceFloatInfPos = 100000000.f;
@@ -88,93 +88,6 @@ __forceinline__ __device__ void block_reduce(T* pval) {
 #undef COLOSSAL_WARP_REDUCE_IMPL
 #undef COLOSSAL_BLOCK_REDUCE_IMPL
 
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes(
-    T* x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = x[tid] + x[tid + i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = x[tid] + x[tid + 32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template <typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(
-    T* x, T val, int lanes = 1,
-    bool share_result = false)  // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-  int blockSize =
-      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
-
-  if (blockSize >= 64) {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-#pragma unroll
-  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
-    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if (tid < 32) {
-    if (blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-#pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1)
-      final =
-          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if (share_result) {
-    if (tid < lanes) x[tid] = final;  // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-}  // namespace utils
+}  // namespace funcs
 }  // namespace cuda
 }  // namespace colossalAI
diff --git a/extensions/csrc/cuda/funcs/unary_functor.h b/extensions/csrc/cuda/funcs/unary_functor.h
index 72c421ea1..ea57fae7a 100644
--- a/extensions/csrc/cuda/funcs/unary_functor.h
+++ b/extensions/csrc/cuda/funcs/unary_functor.h
@@ -15,7 +15,7 @@ namespace funcs {
 
 // Note(LiuYang): As a retrieved table to check which operation is supported
 // already
-enum class UnaryOpType { kLog2Ceil = 0 };
+enum class UnaryOpType { kLog2Ceil = 0, kAbs };
 
 // Note(LiuYang): Implementation of common and simple unary operators should be
 // placed here, otherwise, they should be placed in a new file under functors
@@ -31,6 +31,9 @@ struct UnaryOpFunctor;
   FUNCTION_MODIFIER TO operator()(FROM val) STMTS \
   };
 
+COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(
+    T, T, UnaryOpType::kAbs, HOSTDEVICE, { return std::abs(val); }, typename T)
+
 COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
                                       HOSTDEVICE, {
                                         int log2_value = 0;
diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
index e5766e981..4f589597f 100644
--- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
+++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
@@ -6,6 +6,9 @@
 #include "../common/micros.h"
 #include "../common/mp_type_traits.h"
 
+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
 template
 __device__ void apply_emb_rotary_compute(
     scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
diff --git a/extensions/csrc/cuda/get_cos_and_sin_kernel.cu b/extensions/csrc/cuda/get_cos_and_sin_kernel.cu
index 15b5c5efb..40db089b2 100644
--- a/extensions/csrc/cuda/get_cos_and_sin_kernel.cu
+++ b/extensions/csrc/cuda/get_cos_and_sin_kernel.cu
@@ -3,7 +3,10 @@
 #include "utils/vec_copy.h"
 #include "../common/micros.h"
-#include "stdio.h"
+
+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::get_vec_size;
+
 
 template
 __device__ void apply_cos_and_sin_memcopy(
diff --git a/extensions/csrc/cuda/moe_kernel.cu b/extensions/csrc/cuda/moe_kernel.cu
index 7b28dffe9..a60932c76 100644
--- a/extensions/csrc/cuda/moe_kernel.cu
+++ b/extensions/csrc/cuda/moe_kernel.cu
@@ -4,11 +4,11 @@
 #include
 
-#include "block_reduce.h"
+#include "funcs/reduce_function.h"
 
-using colossalAI::cuda::utils::block_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::block_reduce;
+using colossalAI::cuda::funcs::ReduceType;
 
 template
 __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
diff --git a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
index fe86a8104..d2e0f8734 100644
--- a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
+++ b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
@@ -12,14 +12,98 @@
 #include "multi_tensor_apply.cuh"
 #include "../common/micros.h"
-#include "include/block_reduce.h"
+#include "funcs/reduce_function.h"
 
 #define BLOCK_SIZE 512
 #define ILP 4
 
-using colossalAI::cuda::utils::block_reduce;
-using colossalAI::cuda::utils::reduce_block_into_lanes;
-using colossalAI::cuda::utils::reduce_block_into_lanes_max_op;
+
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes(
+    T* x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = x[tid] + x[tid + i];
+    __syncthreads();
+  }
+
+  T final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = x[tid] + x[tid + 32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
+
+template <typename T>
+__device__ __forceinline__ T reduce_block_into_lanes_max_op(
+    T* x, T val, int lanes = 1,
+    bool share_result = false)  // lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+  int blockSize =
+      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.
+
+  if (blockSize >= 64) {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+#pragma unroll
+  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
+    __syncthreads();
+  }
+
+  T final;
+
+  if (tid < 32) {
+    if (blockSize >= 64)
+      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
+    else
+      final = val;
+    // __SYNCWARP();
+
+#pragma unroll
+    for (int i = 16; i >= lanes; i >>= 1)
+      final =
+          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+  }
+
+  if (share_result) {
+    if (tid < lanes) x[tid] = final;  // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
 
 template
 __device__ __forceinline__ bool is_aligned(T *p) {
diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu
index 33f35ccbd..1b89232f3 100644
--- a/extensions/csrc/cuda/rms_layernorm_kernel.cu
+++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -5,39 +5,20 @@
 #include
 #include
 #include
-#include
-#include "block_reduce.h"
 #include "../common/micros.h"
 #include "funcs/cast_functor.h"
 #include "funcs/binary_functor.h"
+#include "funcs/reduce_function.h"
+#include "utils/vec_type_traits.h"
 
-using colossalAI::cuda::utils::block_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::block_reduce;
+using colossalAI::cuda::funcs::ReduceType;
 using colossalAI::cuda::funcs::CastFunctor;
 using colossalAI::cuda::funcs::BinaryOpFunctor;
 using colossalAI::cuda::funcs::BinaryOpType;
-
-
-// Get type2 from type or vice versa (applied to half and bfloat16)
-template
-struct TypeConverter {
-  using Type = half2;
-};
-
-#define TYPE_CONVERTER_SPECIALIZATION(FROM, TO) \
-  template <>                                   \
-  struct TypeConverter {                        \
-    using Type = TO;                            \
-  };
-
-TYPE_CONVERTER_SPECIALIZATION(half2, at::Half)
-TYPE_CONVERTER_SPECIALIZATION(at::Half, half2)
-TYPE_CONVERTER_SPECIALIZATION(__nv_bfloat162, at::BFloat16)
-TYPE_CONVERTER_SPECIALIZATION(at::BFloat16, __nv_bfloat162)
-
-#undef TYPE_CONVERTER_SPECIALIZATION
+using colossalAI::cuda::utils::VecTypeTrait;
 
 // optimized for half and bf16
 template
@@ -48,7 +29,7 @@ __global__ void rms_layernorm_kernel(
     const float epsilon, const int num_tokens, const int hidden_size) {
-  using scalar2_t = typename TypeConverter::Type;
+  using scalar2_t = typename VecTypeTrait::Type;
   BinaryOpFunctor mul_scalar2t;
 
   __shared__ float s_variance;
@@ -134,7 +115,7 @@ __global__ void fused_add_rms_layernorm_kernel(
     const float epsilon, const int num_tokens, const int hidden_size) {
-  using scalar2_t = typename TypeConverter::Type;
+  using scalar2_t = typename VecTypeTrait::Type;
   BinaryOpFunctor add_scalar2t;
   BinaryOpFunctor mul_scalar2t;
diff --git a/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu b/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu
index e0bb6497a..3e51c4b66 100644
--- a/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu
+++ b/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu
@@ -16,13 +16,14 @@
 #include "../common/micros.h"
 #include "utils/vec_copy.h"
-#include "include/block_reduce.h"
+#include "funcs/reduce_function.h"
 #include "funcs/unary_functor.h"
 
 using colossalAI::cuda::funcs::UnaryOpFunctor;
 using colossalAI::cuda::funcs::UnaryOpType;
-using colossalAI::cuda::utils::warp_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::warp_reduce;
+using colossalAI::cuda::funcs::ReduceType;
+using colossalAI::cuda::utils::copy_vector;
 
 /*
diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu
index d44097b6b..510d98f28 100644
--- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu
+++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu
@@ -16,13 +16,15 @@
 #include "../common/micros.h"
 #include "utils/vec_copy.h"
-#include "include/block_reduce.h"
+#include "funcs/reduce_function.h"
 #include "funcs/unary_functor.h"
 
 using colossalAI::cuda::funcs::UnaryOpFunctor;
 using colossalAI::cuda::funcs::UnaryOpType;
-using colossalAI::cuda::utils::warp_reduce;
-using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::warp_reduce;
+using colossalAI::cuda::funcs::ReduceType;
+using colossalAI::cuda::utils::copy_vector;
+using colossalAI::cuda::utils::copy_zero_vector;
 
 /*
  * Extended softmax (from native aten pytorch) with following additional
diff --git a/extensions/csrc/cuda/utils/vec_copy.h b/extensions/csrc/cuda/utils/vec_copy.h
index 5157ec738..39e28d268 100644
--- a/extensions/csrc/cuda/utils/vec_copy.h
+++ b/extensions/csrc/cuda/utils/vec_copy.h
@@ -1,12 +1,16 @@
 #pragma once
 
-#include
 #include
 #include
 
+#include "../funcs/cast_functor.h"
 #include "vec_type_traits.h"
 
+namespace colossalAI {
+namespace cuda {
+namespace utils {
+
 template
 __device__ __inline__ void copy_vector(T *dst, const T *src) {
   using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type;
@@ -26,7 +30,8 @@ __device__ __inline__ void copy_vector(float *dst, const float *src) {
 template
 __device__ __inline__ void copy_zero_vector(T *dst) {
   using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type;
-  *(reinterpret_cast(dst)) = {0.0};
+  *(reinterpret_cast(dst)) =
+      colossalAI::cuda::funcs::CastFunctor()(0.0f);
 }
 
 template
@@ -50,3 +55,7 @@ int get_vec_size(const torch::Tensor &tensor) {
     return 1;
   }
 }
+
+}  // namespace utils
+}  // namespace cuda
+}  // namespace colossalAI
diff --git a/extensions/csrc/cuda/utils/vec_type_traits.h b/extensions/csrc/cuda/utils/vec_type_traits.h
index 0bd25469a..782518936 100644
--- a/extensions/csrc/cuda/utils/vec_type_traits.h
+++ b/extensions/csrc/cuda/utils/vec_type_traits.h
@@ -1,8 +1,9 @@
 #pragma once
 
-#include
+#include
 #include
 #include
+#include
 
 #include
 
@@ -20,12 +21,14 @@ struct VecTypeTrait {};
   };
 
 VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 2, float)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 4, float2)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::BFloat16, 8, float4)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 2, float)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 4, float2)
-VEC_TYPE_TRAITS_SPECIALIZATION(c10::Half, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 8, float4)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
+VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
 VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
 VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
 VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, float4)
diff --git a/extensions/csrc/scaled_softmax.py b/extensions/csrc/scaled_softmax.py
deleted file mode 100644
index 7c220d60d..000000000
--- a/extensions/csrc/scaled_softmax.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# This code from NVIDIA Megatron:
-# with minor changes.
-
-import enum
-
-import torch
-import torch.nn as nn
-
-from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
-
-try:
-    from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
-except ImportError:
-    scaled_masked_softmax = None
-    scaled_upper_triang_masked_softmax = None
-
-
-class AttnMaskType(enum.Enum):
-    padding = 1
-    causal = 2
-    paddedcausal = 3
-
-
-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
-    """
-    Fused operation which performs following three operations in sequence
-
-    1. Scale the tensor.
-    2. Apply upper triangular mask (typically used in gpt models).
-    3. Perform softmax.
-    """
-
-    @staticmethod
-    def forward(ctx, inputs, scale):
-        global scaled_upper_triang_masked_softmax
-        if scaled_upper_triang_masked_softmax:
-            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
-
-        scale_t = torch.tensor([scale])
-        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
-
-        ctx.save_for_backward(softmax_results, scale_t)
-        return softmax_results
-
-    @staticmethod
-    def backward(ctx, output_grads):
-        softmax_results, scale_t = ctx.saved_tensors
-        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
-
-        return input_grads, None
-
-
-class ScaledMaskedSoftmax(torch.autograd.Function):
-    """
-    Fused operation which performs following three operations in sequence
-
-    1. Scale the tensor.
-    2. Apply the mask.
-    3. Perform softmax.
-    """
-
-    @staticmethod
-    def forward(ctx, inputs, mask, scale):
-        scale_t = torch.tensor([scale])
-
-        # build and load kernel if not pre-built
-        global scaled_masked_softmax
-        if scaled_masked_softmax is None:
-            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
-
-        softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
-        ctx.save_for_backward(softmax_results, scale_t)
-        return softmax_results
-
-    @staticmethod
-    def backward(ctx, output_grads):
-        softmax_results, scale_t = ctx.saved_tensors
-
-        input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
-        return input_grads, None, None, None
-
-
-class FusedScaleMaskSoftmax(nn.Module):
-    """
-    Fused operation: scaling + mask + softmax
-
-    Arguments:
-        input_in_fp16: Flag to indicate if input in fp16 data format.
-        input_in_bf16: Flag to indicate if input in bf16 data format.
-        attn_mask_type: Attention mask type (pad or causal)
-        scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion
-        mask_func: Mask function to be applied.
-        softmax_in_fp32: If True, softmax in performed at fp32 precision.
-        scale: Scaling factor used in input tensor scaling.
-    """
-
-    def __init__(
-        self,
-        input_in_fp16,
-        input_in_bf16,
-        attn_mask_type,
-        scaled_masked_softmax_fusion,
-        mask_func,
-        softmax_in_fp32,
-        scale,
-    ):
-        super(FusedScaleMaskSoftmax, self).__init__()
-        self.input_in_fp16 = input_in_fp16
-        self.input_in_bf16 = input_in_bf16
-        assert not (
-            self.input_in_fp16 and self.input_in_bf16
-        ), "both fp16 and bf16 flags cannot be active at the same time."
-        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
-        self.attn_mask_type = attn_mask_type
-        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
-        self.mask_func = mask_func
-        self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
-
-    def forward(self, input, mask):
-        # [b, np, sq, sk]
-        assert input.dim() == 4
-
-        if self.is_kernel_available(mask, *input.size()):
-            return self.forward_fused_softmax(input, mask)
-        else:
-            return self.forward_torch_softmax(input, mask)
-
-    def is_kernel_available(self, mask, b, np, sq, sk):
-        attn_batches = b * np
-
-        if (
-            self.scaled_masked_softmax_fusion  # user want to fuse
-            and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
-            and 16 < sk <= 2048  # sk must be 16 ~ 2048
-            and sq % 4 == 0  # sq must be divisor of 4
-            and attn_batches % 4 == 0  # np * b must be divisor of 4
-        ):
-            if 0 <= sk <= 2048:
-                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
-
-                if self.attn_mask_type.value > 1:
-                    if attn_batches % batch_per_block == 0:
-                        return True
-                else:
-                    if sq % batch_per_block == 0:
-                        return True
-        return False
-
-    def forward_fused_softmax(self, input, mask):
-        b, np, sq, sk = input.size()
-        scale = self.scale if self.scale is not None else 1.0
-
-        if self.attn_mask_type.value > 1:
-            assert sq == sk, "causal mask is only for self attention"
-
-            # input is 3D tensor (attn_batches, sq, sk)
-            input = input.view(-1, sq, sk)
-            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-            return probs.view(b, np, sq, sk)
-        else:
-            # input is 4D tensor (b, np, sq, sk)
-            return ScaledMaskedSoftmax.apply(input, mask, scale)
-
-    def forward_torch_softmax(self, input, mask):
-        if self.input_in_float16 and self.softmax_in_fp32:
-            input = input.float()
-
-        if self.scale is not None:
-            input = input * self.scale
-        mask_output = self.mask_func(input, mask) if mask is not None else input
-        probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-        if self.input_in_float16 and self.softmax_in_fp32:
-            if self.input_in_fp16:
-                probs = probs.half()
-            else:
-                probs = probs.bfloat16()
-
-        return probs
-
-    def get_batch_per_block(self, sq, sk, b, np):
-        # build and load kernel if not pre-built
-        global scaled_masked_softmax
-        if scaled_masked_softmax is None:
-            scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
-
-        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)