From 7ebdf48ac50ca7bab827ef611551c6c48113b684 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?=
Date: Mon, 8 Apr 2024 11:38:05 +0800
Subject: [PATCH] add cast and op_functor for CUDA built-in types (#5546)

---
 extensions/csrc/cuda/funcs/cast_functor.h    |  74 +++++++++++
 extensions/csrc/cuda/funcs/op_functor.h      |  84 +++++++++++--
 extensions/csrc/cuda/include/block_reduce.h  |   4 +-
 extensions/csrc/cuda/rms_layernorm_kernel.cu |  31 +++--
 extensions/csrc/cuda/utils/cuda_type_utils.h | 122 -------------------
 extensions/csrc/cuda/utils/micros.h          |   4 +
 6 files changed, 173 insertions(+), 146 deletions(-)
 create mode 100644 extensions/csrc/cuda/funcs/cast_functor.h
 delete mode 100644 extensions/csrc/cuda/utils/cuda_type_utils.h

diff --git a/extensions/csrc/cuda/funcs/cast_functor.h b/extensions/csrc/cuda/funcs/cast_functor.h
new file mode 100644
index 000000000..623e1cdeb
--- /dev/null
+++ b/extensions/csrc/cuda/funcs/cast_functor.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <functional>
+
+#include "../utils/micros.h"
+
+// Note(LiuYang): This file provides base cast operations for data types,
+// including POD types and CUDA built-in types such as half and __nv_bfloat16.
+
+namespace colossalAI {
+namespace cuda {
+namespace funcs {
+
+// Get type2 from type or vice versa (applied to half and bfloat16)
+template <typename T>
+struct TypeConverter {
+  using Type = half2;
+};  // keep for generality
+
+template <>
+struct TypeConverter<half2> {
+  using Type = at::Half;
+};
+
+template <>
+struct TypeConverter<at::Half> {
+  using Type = half2;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat162> {
+  using Type = at::BFloat16;
+};
+
+template <>
+struct TypeConverter<at::BFloat16> {
+  using Type = __nv_bfloat162;
+};
+
+template <typename From, typename To>
+struct CastFunctor : public std::unary_function<From, To> {
+  HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
+};
+
+#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT,             \
+                                             FUNCTION_MODIFIER)          \
+  template <>                                                            \
+  struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> {  \
+    FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; }           \
+  };
+
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
+                                     DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, nv_bfloat162,
+                                     __float2bfloat162_rn(val), DEVICE)
+
+#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
+
+}  // namespace funcs
+}  // namespace cuda
+}  // namespace colossalAI
diff --git a/extensions/csrc/cuda/funcs/op_functor.h b/extensions/csrc/cuda/funcs/op_functor.h
index 7c00bcced..0398ea97b 100644
--- a/extensions/csrc/cuda/funcs/op_functor.h
+++ b/extensions/csrc/cuda/funcs/op_functor.h
@@ -1,31 +1,91 @@
 #pragma once
 
 #include <cuda.h>
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <functional>
 
+#include "../utils/micros.h"
+
 namespace colossalAI {
 namespace cuda {
 namespace funcs {
 
-enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
+enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };
 
-template <typename T, BinaryOpType Op>
+// Note(LiuYang): This file provides base math operations for data types,
+// including POD types and CUDA built-in types such as half and __nv_bfloat16.
+template <typename T, BinaryOpType op_type>
 struct BinaryOpFunctor;
 
-template <typename T>
-struct BinaryOpFunctor<T, BinaryOpType::kAdd>
-    : public std::binary_function<T, T, T> {
-  __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
-};
+#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT,    \
+                                               FUNCTION_MODIFIER, ARGS...) \
+  template <ARGS>                                                          \
+  struct BinaryOpFunctor<T, BINARY_OP_TYPE>                                \
+      : public std::binary_function<T, T, T> {                             \
+    FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; }          \
+  };
 
-template <typename T>
-struct BinaryOpFunctor<T, BinaryOpType::kMax>
-    : public std::binary_function<T, T, T> {
-  __host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
-};
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs * rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
+                                       HOSTDEVICE, typename T)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
+                                       HOSTDEVICE, typename T)
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
+                                       __hadd(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
+                                       __hadd2(lhs, rhs), DEVICE)
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
+                                       __hadd(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
+                                       __hadd2(lhs, rhs), DEVICE)
+#else
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
+                                       __float2bfloat16(__bfloat162float(lhs) +
+                                                        __bfloat162float(rhs)),
+                                       DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat162, BinaryOpType::kAdd,
+    __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
+                          __high2float(lhs) + __high2float(rhs)),
+    DEVICE)
+#endif
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
+                                       __hmul(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
+                                       __hmul2(lhs, rhs), DEVICE)
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
+                                       __hmul(lhs, rhs), DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
+                                       __hmul2(lhs, rhs), DEVICE)
+#else
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
+                                       __float2bfloat16(__bfloat162float(lhs) *
+                                                        __bfloat162float(rhs)),
+                                       DEVICE)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat162, BinaryOpType::kMul,
+    __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
+                          __high2float(lhs) * __high2float(rhs)),
+    DEVICE)
+#endif
+
+#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
 
 }  // namespace funcs
 }  // namespace cuda
diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h
index d262091c4..6f6db6f77 100644
--- a/extensions/csrc/cuda/include/block_reduce.h
+++ b/extensions/csrc/cuda/include/block_reduce.h
@@ -22,12 +22,12 @@ struct GetOpForReduceType;
 
 template <typename T>
 struct GetOpForReduceType<T, ReduceType::kSum> {
-  using Op = funcs::BinaryOpFunctor<T,funcs::BinaryOpType::kAdd>;
+  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
 };
 
 template <typename T>
 struct GetOpForReduceType<T, ReduceType::kMax> {
-  using Op = funcs::BinaryOpFunctor<T,funcs::BinaryOpType::kMax>;
+  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
 };
 
 #define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu
index 9d96472bd..c39e44d87 100644
--- a/extensions/csrc/cuda/rms_layernorm_kernel.cu
+++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -10,10 +10,15 @@
 
 #include "block_reduce.h"
 #include "../common/micros.h"
-#include "utils/cuda_type_utils.h"
+#include "funcs/cast_functor.h"
+#include "funcs/op_functor.h"
 
 using colossalAI::cuda::utils::block_reduce;
 using colossalAI::cuda::utils::ReduceType;
+using colossalAI::cuda::funcs::TypeConverter;
+using colossalAI::cuda::funcs::CastFunctor;
+using colossalAI::cuda::funcs::BinaryOpFunctor;
+using colossalAI::cuda::funcs::BinaryOpType;
 
 #define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
   if (DATA_SIZE == 2) {                                                    \
@@ -53,6 +58,7 @@ __global__ void rms_layernorm_kernel(
     const int num_tokens,
     const int hidden_size) {
   using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  BinaryOpFunctor<scalar2_t, BinaryOpType::kMul> mul_scalar2t;
   __shared__ float s_variance;
 
   /*
@@ -72,12 +78,13 @@
 
   float variance = 0.0f;
   int row_offset = blockIdx.x * hidden_size / 2;
+
 #pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
     x_local[cnt] = input_ptr[id];
-    float v1 = cuda_cast<float>(x_local[cnt].x);
-    float v2 = cuda_cast<float>(x_local[cnt].y);
+    float v1 = CastFunctor<scalar_t, float>()(x_local[cnt].x);
+    float v2 = CastFunctor<scalar_t, float>()(x_local[cnt].y);
     variance += v1 * v1 + v2 * v2;
   }
   block_reduce<float, ReduceType::kSum>(&variance);
@@ -86,11 +93,11 @@
   }
   __syncthreads();
 
-  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+  scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);
 #pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
-    out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+    out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
   }
 }
 
@@ -137,6 +144,9 @@ __global__ void fused_add_rms_layernorm_kernel(
     const int num_tokens,
     const int hidden_size) {
   using scalar2_t = typename TypeConverter<scalar_t>::Type;
+  BinaryOpFunctor<scalar2_t, BinaryOpType::kAdd> add_scalar2t;
+  BinaryOpFunctor<scalar2_t, BinaryOpType::kMul> mul_scalar2t;
+
   __shared__ float s_variance;
   scalar2_t x_local[4];
 
@@ -151,9 +161,9 @@ for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
     x_local[cnt] = input_ptr[id];
-    x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
-    float v1 = cuda_cast<float>(x_local[cnt].x);
-    float v2 = cuda_cast<float>(x_local[cnt].y);
+    x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
+    float v1 = CastFunctor<scalar_t, float>()(x_local[cnt].x);
+    float v2 = CastFunctor<scalar_t, float>()(x_local[cnt].y);
     variance += v1 * v1 + v2 * v2;
     residual_ptr[id] = x_local[cnt];
   }
   block_reduce<float, ReduceType::kSum>(&variance);
@@ -163,11 +173,12 @@
   }
   __syncthreads();
 
-  scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
+  scalar2_t s_variance_2 = CastFunctor<float, scalar2_t>()(s_variance);
+
 #pragma unroll unroll_factor
   for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
     int id = row_offset + idx;
-    input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
+    input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
   }
 }
 
diff --git a/extensions/csrc/cuda/utils/cuda_type_utils.h b/extensions/csrc/cuda/utils/cuda_type_utils.h
deleted file mode 100644
index 35d4c1492..000000000
--- a/extensions/csrc/cuda/utils/cuda_type_utils.h
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * This code from NVIDIA FasterTransformer:
- * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
- */
-
-#pragma once
-
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-
-template <typename T>
-inline __device__ T add(T a, T b) {
-  return a + b;
-}
-
-template <>
-inline __device__ half2 add(half2 a, half2 b) {
-  return __hadd2(a, b);
-}
-
-template <>
-inline __device__ half add(half a, half b) {
-  return __hadd(a, b);
-}
-
-#if ENABLE_BF16
-template <>
-inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
-  return bf16hadd2(a, b);
-}
-
-template <>
-inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
-  return bf16hadd(a, b);
-}
-
-#endif  // ENABLE_BF16
-
-template <typename T>
-inline __device__ T mul(T a, T b, T c) {
-  return a * b * c;
-}
-
-template <>
-inline __device__ half2 mul(half2 a, half2 b, half2 c) {
-  return __hmul2(__hmul2(a, b), c);
-}
-
-#if ENABLE_BF16
-template <>
-inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
-                                    __nv_bfloat16 c) {
-  return bf16hmul(a, b, c);
-}
-
-inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
-                                     __nv_bfloat162 c) {
-  return bf16hmul2(a, b, c);
-}
-#endif  // ENABLE_BF16
-
-template <typename T_OUT, typename T_IN>
-__device__ inline T_OUT cuda_cast(T_IN val) {
-  return val;
-}
-
-template <>
-__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
-  return make_float2(val.x, val.y);
-}
-template <>
-__device__ inline float2 cuda_cast<float2, float>(float val) {
-  return make_float2(val, val);
-}
-template <>
-__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
-  return __half22float2(val);
-}
-template <>
-__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
-  return __float22half2_rn(val);
-}
-template <>
-__device__ inline half2 cuda_cast<half2, float>(float val) {
-  return __float2half2_rn(val);
-}
-template <>
-__device__ inline half2 cuda_cast<half2, half>(half val) {
-  return __half2half2(val);
-}
-template <>
-__device__ inline float cuda_cast<float, half>(half val) {
-  return __half2float(val);
-}
-
-// Get type2 from type or vice versa (applied to half and bfloat16)
-template <typename T>
-struct TypeConverter {
-  using Type = half2;
-};  // keep for generality
-
-template <>
-struct TypeConverter<half2> {
-  using Type = at::Half;
-};
-
-template <>
-struct TypeConverter<at::Half> {
-  using Type = half2;
-};
-
-#if ENABLE_BF16
-template <>
-struct TypeConverter<__nv_bfloat162> {
-  using Type = at::BFloat16;
-};
-
-template <>
-struct TypeConverter<at::BFloat16> {
-  using Type = __nv_bfloat162;
-};
-#endif  // ENABLE_BF16
diff --git a/extensions/csrc/cuda/utils/micros.h b/extensions/csrc/cuda/utils/micros.h
index 8dd8be166..aaa2fc1ef 100644
--- a/extensions/csrc/cuda/utils/micros.h
+++ b/extensions/csrc/cuda/utils/micros.h
@@ -12,3 +12,7 @@
       throw std::runtime_error(cudaGetErrorString(status)); \
     }                                                       \
   }
+
+#define HOST __host__
+#define DEVICE __device__
+#define HOSTDEVICE __host__ __device__
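
For reference, a minimal sketch of how the new functors compose at a call site. The kernel below is hypothetical and not part of the patch, but every specialization it relies on is one defined in cast_functor.h and op_functor.h above:

// Hypothetical usage sketch (not part of the patch). Scales a packed half2
// buffer by a float factor through the functor interface:
// CastFunctor<float, half2> lowers to __float2half2_rn and
// BinaryOpFunctor<half2, BinaryOpType::kMul> lowers to __hmul2,
// per the specializations added in this patch.
#include <cuda_fp16.h>

#include "funcs/cast_functor.h"
#include "funcs/op_functor.h"

using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;
using colossalAI::cuda::funcs::CastFunctor;

__global__ void scale_half2(half2* data, float factor, int n) {
  CastFunctor<float, half2> to_half2;                    // float -> half2 splat
  BinaryOpFunctor<half2, BinaryOpType::kMul> mul_half2;  // packed half2 multiply
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] = mul_half2(data[i], to_half2(factor));
  }
}

// Example launch (n counts half2 elements, i.e. half the half count):
//   scale_half2<<<(n + 255) / 256, 256>>>(buf, 0.5f, n);

Since both functors are empty types whose operator() is a single intrinsic, the layer adds no overhead over calling the intrinsics directly once inlined.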
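A note on the design this patch adopts: expressing each operation as an enum-tagged functor specialization keeps operator selection a compile-time property, which is what lets block_reduce.h map ReduceType::kSum and ReduceType::kMax to BinaryOpFunctor<T, funcs::BinaryOpType::kAdd> and kMax through GetOpForReduceType with no runtime branching. The FUNCTION_MODIFIER macro argument (HOST, DEVICE, HOSTDEVICE from utils/micros.h) lets each specialization declare the narrowest qualifier its implementation supports: the generic POD versions are host-and-device, while the half and bfloat16 versions are device-only because they lower to device intrinsics. The __CUDA_ARCH__ >= 800 branches exist because native __nv_bfloat16 arithmetic intrinsics such as __hadd and __hmul2 require sm_80; older architectures fall back to a round-trip through float.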