From a2878e39f42f509f237f3d3fd0741f53e3feff0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Mon, 1 Apr 2024 15:34:25 +0800 Subject: [PATCH] [Inference] Add Reduce Utils (#5537) * add reduce utils * add using to delele namespace prefix --- extensions/csrc/common/micros.h | 10 - extensions/csrc/cuda/funcs/op_functor.h | 32 ++ extensions/csrc/cuda/include/block_reduce.h | 375 ++++-------------- extensions/csrc/cuda/layer_norm_kernel.cu | 32 +- extensions/csrc/cuda/moe_kernel.cu | 45 ++- extensions/csrc/cuda/multi_tensor_apply.cuh | 2 +- .../csrc/cuda/multi_tensor_l2norm_kernel.cu | 28 +- .../csrc/cuda/multi_tensor_lamb_kernel.cu | 6 +- extensions/csrc/cuda/rms_layernorm_kernel.cu | 11 +- 9 files changed, 179 insertions(+), 362 deletions(-) create mode 100644 extensions/csrc/cuda/funcs/op_functor.h diff --git a/extensions/csrc/common/micros.h b/extensions/csrc/common/micros.h index 12cd78046..fd489d764 100644 --- a/extensions/csrc/common/micros.h +++ b/extensions/csrc/common/micros.h @@ -9,16 +9,6 @@ #include -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif - #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) { \ case at::ScalarType::Half: { \ diff --git a/extensions/csrc/cuda/funcs/op_functor.h b/extensions/csrc/cuda/funcs/op_functor.h new file mode 100644 index 000000000..7c00bcced --- /dev/null +++ b/extensions/csrc/cuda/funcs/op_functor.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +#include + +namespace colossalAI { +namespace cuda { +namespace funcs { + +enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin }; + +template +struct BinaryOpFunctor; + +template +struct BinaryOpFunctor + : public std::binary_function { + __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; } +}; + +template +struct BinaryOpFunctor + : public std::binary_function { + __host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); } +}; + +} // namespace funcs +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h index 86409136b..d262091c4 100644 --- a/extensions/csrc/cuda/include/block_reduce.h +++ b/extensions/csrc/cuda/include/block_reduce.h @@ -1,319 +1,100 @@ -/* Copyright 2021 The LightSeq Team - Copyright Tencent/TurboTransformers - This block_reduce_n is adapted from Tencent/TurboTransformers -*/ #pragma once + #include #include #include +#include "../funcs/op_functor.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +const float kReduceFloatInfNeg = -100000000.f; +const float kReduceFloatInfPos = 100000000.f; +const int kWarpSize = 32; +const unsigned int kWarpReduceMask = 0xffffffff; + enum class ReduceType { kMax = 0, kSum }; -const unsigned int WARP_REDUCE_MASK = 0xffffffff; -const float REDUCE_FLOAT_INF_NEG = -100000000.f; -const float REDUCE_FLOAT_INF_POS = 100000000.f; -const unsigned int WARP_REDUCE_SIZE = 32; + +template +struct GetOpForReduceType; template -__forceinline__ __device__ T warpReduceSum(T val) { - for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) - val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); - return val; -} +struct GetOpForReduceType { + using Op = funcs::BinaryOpFunctor; +}; -/* Calculate the sum of all elements in a block */ template -__forceinline__ __device__ T blockReduceSum(T val) { - static __shared__ T shared[32]; - int 
lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; +struct GetOpForReduceType { + using Op = funcs::BinaryOpFunctor; +}; - val = warpReduceSum(val); +#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \ + for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = \ + OP(*(VAL_PTR + offset), \ + __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \ + } - if (lane == 0) shared[wid] = val; - __syncthreads(); +#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \ + COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES) - val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f; - val = warpReduceSum(val); - return val; +#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \ + DEFAULT_VALUE, REDUCE_TYPE) \ + __shared__ T shm[LANES][32]; \ + int lane_id = threadIdx.x & 0x1f; \ + int warp_id = threadIdx.x >> 5; \ + \ + warp_reduce(VAL_PTR); \ + if (lane_id == 0) { \ + for (int offset = 0; offset < LANES; ++offset) { \ + shm[offset][warp_id] = *(VAL_PTR + offset); \ + } \ + } \ + __syncthreads(); \ + \ + for (int offset = 0; offset < LANES; ++offset) { \ + *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \ + ? shm[offset][lane_id] \ + : static_cast(DEFAULT_VALUE); \ + } \ + warp_reduce(VAL_PTR); + +template +__forceinline__ __device__ void warp_reduce(T* pval) { + typename GetOpForReduceType::Op op; + COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes); } -template -__inline__ __device__ void blockReduce(float *pval); - -// use template to make code more concise -template -__inline__ __device__ void warpReduce(float *pval); - -// static -template <> -__inline__ __device__ void warpReduce(float *pval) { - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32)); - *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32)); +template +__forceinline__ __device__ constexpr T GetDefaultValueForBlockReduce() { + if constexpr (rtype == ReduceType::kSum) { + return static_cast(0.0f); + } else if constexpr (rtype == ReduceType::kMax) { + return static_cast(kReduceFloatInfNeg); + } } -template <> -__inline__ __device__ void warpReduce(float *pval) { - float val0_tmp, val1_tmp; -#define WarpReduceMaxOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - *(pval) = max(val0_tmp, *(pval)); \ - *(pval + 1) = max(val1_tmp, *(pval + 1)); - - WarpReduceMaxOneStep(16, 32); - WarpReduceMaxOneStep(8, 32); - WarpReduceMaxOneStep(4, 32); - WarpReduceMaxOneStep(2, 32); - WarpReduceMaxOneStep(1, 32); -#undef WarpReduceMaxOneStep +template +__forceinline__ __device__ void block_reduce(T* pval) { + constexpr T kDefaultValue = GetDefaultValueForBlockReduce(); + typename GetOpForReduceType::Op op; + COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue, + rtype); } -template <> -__inline__ __device__ void warpReduce(float *pval) { - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 
32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32); - *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32); -} - -/* - * Unorll for loop for warpreduce to - * imporve instruction issue efficiency - * ElemX means there are X numbers to be summed - */ - -template <> -__inline__ __device__ void warpReduce(float *pval) { - float val0_tmp, val1_tmp; -#define WarpReduceSumOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - *(pval + 0) += val0_tmp; \ - *(pval + 1) += val1_tmp - - WarpReduceSumOneStep(16, 32); - WarpReduceSumOneStep(8, 32); - WarpReduceSumOneStep(4, 32); - WarpReduceSumOneStep(2, 32); - WarpReduceSumOneStep(1, 32); - -#undef WarpReduceSumOneStep -} - -template <> -__inline__ __device__ void warpReduce(float *pval) { - float val0_tmp, val1_tmp, val2_tmp, val3_tmp; -#define WarpReduceSumOneStep(a, b) \ - val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \ - val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \ - val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \ - val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \ - *(pval + 0) += val0_tmp; \ - *(pval + 1) += val1_tmp; \ - *(pval + 2) += val2_tmp; \ - *(pval + 3) += val3_tmp - - WarpReduceSumOneStep(16, 32); - WarpReduceSumOneStep(8, 32); - WarpReduceSumOneStep(4, 32); - WarpReduceSumOneStep(2, 32); - WarpReduceSumOneStep(1, 32); -#undef WarpReduceSumOneStep -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = 0.f; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 2; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = 0.f; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 4; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = 0.f; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; 
- int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = REDUCE_FLOAT_INF_NEG; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = REDUCE_FLOAT_INF_NEG; - } - } - warpReduce(pval); -} - -template <> -__inline__ __device__ void blockReduce(float *pval) { - const int num = 1; - static __shared__ float shared[num][32]; - int lane_id = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduce(pval); - - if (lane_id == 0) { -#pragma unroll - for (int i = 0; i < num; ++i) { - shared[i][wid] = *(pval + i); - } - } - __syncthreads(); - - if (threadIdx.x < (blockDim.x >> 5)) { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = shared[i][lane_id]; - } - } else { -#pragma unroll - for (int i = 0; i < num; ++i) { - *(pval + i) = REDUCE_FLOAT_INF_NEG; - } - } - warpReduce(pval); -} +#undef COLOSSAL_SHFL_FUNCTION +#undef COLOSSAL_WARP_REDUCE_IMPL +#undef COLOSSAL_BLOCK_REDUCE_IMPL template __device__ __forceinline__ T reduce_block_into_lanes( - T *x, T val, int lanes = 1, + T* x, T val, int lanes = 1, bool share_result = false) // lanes is intended to be <= 32. { int tid = threadIdx.x + threadIdx.y * blockDim.x; @@ -356,7 +137,7 @@ __device__ __forceinline__ T reduce_block_into_lanes( template __device__ __forceinline__ T reduce_block_into_lanes_max_op( - T *x, T val, int lanes = 1, + T* x, T val, int lanes = 1, bool share_result = false) // lanes is intended to be <= 32. { int tid = threadIdx.x + threadIdx.y * blockDim.x; @@ -397,3 +178,7 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op( return final; } + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/layer_norm_kernel.cu b/extensions/csrc/cuda/layer_norm_kernel.cu index 17d5b10f4..8239adc9f 100644 --- a/extensions/csrc/cuda/layer_norm_kernel.cu +++ b/extensions/csrc/cuda/layer_norm_kernel.cu @@ -606,11 +606,11 @@ void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, using namespace at; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", - HostApplyLayerNorm(output->DATA_PTR(), - mean->DATA_PTR(), invvar->DATA_PTR(), - input->DATA_PTR(), n1, n2, epsilon, - gamma != NULL ? gamma->DATA_PTR() : NULL, - beta != NULL ? beta->DATA_PTR() : NULL);) + HostApplyLayerNorm(output->data_ptr(), + mean->data_ptr(), invvar->data_ptr(), + input->data_ptr(), n1, n2, epsilon, + gamma != NULL ? gamma->data_ptr() : NULL, + beta != NULL ? 
beta->data_ptr() : NULL);) } template @@ -633,14 +633,14 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, {part_size, n2}, input->options().dtype(at::ScalarType::Float)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); cuComputePartGradGammaBeta<<>>( - dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), - part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR()); + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr()); const dim3 threads3(32, 8, 1); const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(U); cuComputeGradGammaBeta<<>>( - part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), part_size, + part_grad_gamma.data_ptr(), part_grad_beta.data_ptr(), part_size, n1, n2, grad_gamma, grad_beta); } @@ -651,7 +651,7 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar, const dim3 threads1(32, 4, 1); int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; cuComputeGradInput<<>>( - dout, input->DATA_PTR(), n1, n2, mean, invvar, U(epsilon), gamma, + dout, input->data_ptr(), n1, n2, mean, invvar, U(epsilon), gamma, grad_input); } @@ -671,13 +671,13 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel", HostLayerNormGradient( - dout->DATA_PTR(), mean->DATA_PTR(), - invvar->DATA_PTR(), input, n1, n2, + dout->data_ptr(), mean->data_ptr(), + invvar->data_ptr(), input, n1, n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. - gamma != NULL ? gamma->DATA_PTR() : NULL, - gamma != NULL ? beta->DATA_PTR() : NULL, epsilon, - grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL);) + gamma != NULL ? gamma->data_ptr() : NULL, + gamma != NULL ? beta->data_ptr() : NULL, epsilon, + grad_input->data_ptr(), + gamma != NULL ? grad_gamma->data_ptr() : NULL, + gamma != NULL ? grad_beta->data_ptr() : NULL);) } diff --git a/extensions/csrc/cuda/moe_kernel.cu b/extensions/csrc/cuda/moe_kernel.cu index 66c1e6bd2..7b28dffe9 100644 --- a/extensions/csrc/cuda/moe_kernel.cu +++ b/extensions/csrc/cuda/moe_kernel.cu @@ -6,6 +6,10 @@ #include "block_reduce.h" + +using colossalAI::cuda::utils::block_reduce; +using colossalAI::cuda::utils::ReduceType; + template __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { assert(cols % pack_size == 0); @@ -157,8 +161,7 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, BlockStore(ts_store).Store(src_row + idx, grad); } - - blockReduce(&thread_sum); + block_reduce(&thread_sum); if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); } @@ -230,7 +233,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, BlockStore(ts_store).Store(src_row2 + idx, sgrad2); } - blockReduce(thread_sum); + block_reduce(thread_sum); if (threadIdx.x == 0) *weight_grad1 = static_cast(thread_sum[0]); @@ -566,10 +569,10 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, DISPATCH_FLOAT_AND_HALF( batch_tokens.scalar_type(), "moe dispatch forward", moe_dpch_fwd_launch( - batch_tokens.data(), res.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + batch_tokens.data_ptr(), res.data_ptr(), + mask[0].data_ptr(), k == 1 ? 
nullptr : mask[1].data_ptr(), + dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, h)); return res; } @@ -586,10 +589,10 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, DISPATCH_FLOAT_AND_HALF( expert_grad.scalar_type(), "moe dispatch backward", moe_dpch_bwd_launch( - res.data(), expert_grad.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + res.data_ptr(), expert_grad.data_ptr(), + mask[0].data_ptr(), k == 1 ? nullptr : mask[1].data_ptr(), + dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, h)); return res; } @@ -609,10 +612,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, DISPATCH_FLOAT_AND_HALF( expert_tokens.scalar_type(), "moe combine forward", moe_cb_fwd_launch( - expert_tokens.data(), res.data(), - logits.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + expert_tokens.data_ptr(), res.data_ptr(), + logits.data_ptr(), mask[0].data_ptr(), + k == 1 ? nullptr : mask[1].data_ptr(), dest_idx[0].data_ptr(), + k == 1 ? dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, e, c, h)); return res; @@ -636,11 +639,11 @@ std::vector moe_combine_cuda_backward( DISPATCH_FLOAT_AND_HALF( tokens_grad.scalar_type(), "moe combine backward", moe_cb_bwd_launch( - tokens_grad.data(), egrad.data(), - expert_tokens.data(), logits.data(), - wgrad.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + tokens_grad.data_ptr(), egrad.data_ptr(), + expert_tokens.data_ptr(), logits.data_ptr(), + wgrad.data_ptr(), mask[0].data_ptr(), + k == 1 ? nullptr : mask[1].data_ptr(), dest_idx[0].data_ptr(), + k == 1 ? 
dest_idx[0].data_ptr() : dest_idx[1].data_ptr(), s, e, c, h)); return {egrad, wgrad}; @@ -653,7 +656,7 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { const int s = mask.size(0), e = mask.size(1); auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); - cumsum_launch(mask.data(), res.data(), s, e); + cumsum_launch(mask.data_ptr(), res.data_ptr(), s, e); return res; } diff --git a/extensions/csrc/cuda/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh index 01a858661..799ccfa73 100644 --- a/extensions/csrc/cuda/multi_tensor_apply.cuh +++ b/extensions/csrc/cuda/multi_tensor_apply.cuh @@ -104,7 +104,7 @@ void multi_tensor_apply( if (tensors_full || blocks_full || last_chunk) { // using accscalar_t = acc_type; multi_tensor_apply_kernel<<>>( - chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); + chunk_size, noop_flag.data_ptr(), tl, callable, args...); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu index 57a79f7a8..fe86a8104 100644 --- a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu @@ -17,6 +17,10 @@ #define BLOCK_SIZE 512 #define ILP 4 +using colossalAI::cuda::utils::block_reduce; +using colossalAI::cuda::utils::reduce_block_into_lanes; +using colossalAI::cuda::utils::reduce_block_into_lanes_max_op; + template __device__ __forceinline__ bool is_aligned(T *p) { return ((uint64_t)p) % (ILP * sizeof(T)) == 0; @@ -290,8 +294,8 @@ std::tuple multi_tensor_l2norm_cuda( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - L2NormFunctor(), output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, + L2NormFunctor(), output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, per_tensor, max_chunks_per_tensor);) AT_CUDA_CHECK(cudaGetLastError()); @@ -304,10 +308,10 @@ std::tuple multi_tensor_l2norm_cuda( const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); auto stream = at::cuda::getCurrentCUDAStream(); cleanup<<>>( - output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, - ret.DATA_PTR(), - per_tensor ? ret_per_tensor.DATA_PTR() : nullptr, per_tensor, + output.data_ptr(), + per_tensor ? output_per_tensor.data_ptr() : nullptr, + ret.data_ptr(), + per_tensor ? 
ret_per_tensor.data_ptr() : nullptr, per_tensor, max_chunks_per_tensor); return std::tuple(ret, ret_per_tensor); @@ -350,15 +354,15 @@ void multi_tensor_norm_out_cuda( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - MaxNormFunctor(), output.DATA_PTR(), - output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + MaxNormFunctor(), output.data_ptr(), + output_per_tensor.data_ptr(), true, max_chunks_per_tensor);) } else { DISPATCH_FLOAT_AND_HALF( tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", multi_tensor_apply<1>( BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - L2NormFunctor(), output.DATA_PTR(), - output_per_tensor.DATA_PTR(), true, max_chunks_per_tensor);) + L2NormFunctor(), output.data_ptr(), + output_per_tensor.data_ptr(), true, max_chunks_per_tensor);) } AT_CUDA_CHECK(cudaGetLastError()); @@ -375,8 +379,8 @@ void multi_tensor_norm_out_cuda( const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); auto stream = at::cuda::getCurrentCUDAStream(); cleanup_v2<<>>( - output.DATA_PTR(), output_per_tensor.DATA_PTR(), - ret.DATA_PTR(), out.DATA_PTR(), true, max_chunks_per_tensor, + output.data_ptr(), output_per_tensor.data_ptr(), + ret.data_ptr(), out.data_ptr(), true, max_chunks_per_tensor, norm_type, alpha, beta); return; diff --git a/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu b/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu index 50dfc56bc..82c02f36d 100644 --- a/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu @@ -333,7 +333,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, beta3, // 1-beta1 or 1 depends on averaging mode bias_correction1, bias_correction2, epsilon, (adamMode_t)mode, weight_decay, - global_grad_norm.DATA_PTR(), max_grad_norm);) + global_grad_norm.data_ptr(), max_grad_norm);) // Compute update norms auto update_norm_tuple = @@ -346,8 +346,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list, LAMBStage2Functor(), - std::get<1>(param_norm_tuple).DATA_PTR(), - std::get<1>(update_norm_tuple).DATA_PTR(), + std::get<1>(param_norm_tuple).data_ptr(), + std::get<1>(update_norm_tuple).data_ptr(), lr, weight_decay, use_nvlamb);) AT_CUDA_CHECK(cudaGetLastError()); diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 50f26510e..9d96472bd 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -12,6 +12,9 @@ #include "../common/micros.h" #include "utils/cuda_type_utils.h" +using colossalAI::cuda::utils::block_reduce; +using colossalAI::cuda::utils::ReduceType; + #define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) 
\ if (DATA_SIZE == 2) { \ switch (TYPE) { \ @@ -77,7 +80,7 @@ __global__ void rms_layernorm_kernel( float v2 = cuda_cast(x_local[cnt].y); variance += v1 * v1 + v2 * v2; } - variance = blockReduceSum(variance); + block_reduce(&variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -111,7 +114,7 @@ __global__ void general_rms_layernorm_kernel( x_local[cnt] = (float) input[id]; variance += x_local[cnt] * x_local[cnt]; } - variance = blockReduceSum(variance); + block_reduce(&variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -154,7 +157,7 @@ __global__ void fused_add_rms_layernorm_kernel( variance += v1 * v1 + v2 * v2; residual_ptr[id] = x_local[cnt]; } - variance = blockReduceSum(variance); + block_reduce(&variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -190,7 +193,7 @@ __global__ void general_fused_add_rms_layernorm_kernel( variance += x_local[cnt] * x_local[cnt]; residual[id] = (scalar_t) x_local[cnt]; } - variance = blockReduceSum(variance); + block_reduce(&variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); }
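
Reviewer note on the new reduce utilities: the sketch below is illustrative only and is not code from this patch. It spells out, for a single float and a sum reduction, the pattern that COLOSSAL_WARP_REDUCE_IMPL and COLOSSAL_BLOCK_REDUCE_IMPL expand to (a warp-shuffle butterfly, per-warp partials staged through shared memory, then a second warp reduction), so the new call sites such as block_reduce(&variance) in rms_layernorm_kernel.cu are easier to check. Every name in the sketch (warp_sum_sketch, block_sum_sketch, row_sum_of_squares, kWarpMask) is hypothetical, and the template-parameter layout of the real warp_reduce/block_reduce (element type, ReduceType, lane count) is an assumption, since the angle-bracket arguments are not visible in the hunks above.

// reduce_sketch.cu -- standalone illustration of the reduction pattern;
// NOT the library code from extensions/csrc/cuda/include/block_reduce.h.
#include <cstdio>
#include <cuda_runtime.h>

constexpr unsigned int kWarpMask = 0xffffffffu;

// Butterfly warp reduction: after the loop every lane of the warp holds the warp sum.
__device__ __forceinline__ float warp_sum_sketch(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(kWarpMask, val, offset, 32);
  }
  return val;
}

// Block reduction built from two warp reductions staged through shared memory,
// mirroring what COLOSSAL_BLOCK_REDUCE_IMPL does for lanes == 1 and ReduceType::kSum.
__device__ __forceinline__ float block_sum_sketch(float val) {
  __shared__ float shm[32];
  int lane_id = threadIdx.x & 0x1f;
  int warp_id = threadIdx.x >> 5;

  val = warp_sum_sketch(val);
  if (lane_id == 0) shm[warp_id] = val;
  __syncthreads();

  // Only the first (blockDim.x >> 5) threads reload per-warp partials; the
  // rest contribute the identity element of the reduction (0 for a sum).
  val = (threadIdx.x < (blockDim.x >> 5)) ? shm[lane_id] : 0.0f;
  return warp_sum_sketch(val);
}

// One block per row: accumulate x[i]^2 the way the RMSNorm kernels accumulate
// variance, then reduce across the block; lane 0 of warp 0 writes the result.
__global__ void row_sum_of_squares(const float* x, float* out, int cols) {
  float acc = 0.0f;
  const float* row = x + blockIdx.x * cols;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    acc += row[i] * row[i];
  }
  acc = block_sum_sketch(acc);
  if (threadIdx.x == 0) out[blockIdx.x] = acc;
}

int main() {
  const int rows = 4, cols = 1024;
  float *x, *out;
  cudaMallocManaged(&x, rows * cols * sizeof(float));
  cudaMallocManaged(&out, rows * sizeof(float));
  for (int i = 0; i < rows * cols; ++i) x[i] = 1.0f;  // each row reduces to cols

  row_sum_of_squares<<<rows, 256>>>(x, out, cols);
  cudaDeviceSynchronize();
  for (int r = 0; r < rows; ++r) printf("row %d: %.1f\n", r, out[r]);

  cudaFree(x);
  cudaFree(out);
  return 0;
}

The two-stage shape of the block reduction is also why the patch introduces GetDefaultValueForBlockReduce: the lanes that do not reload a per-warp partial must contribute the identity element of the chosen reduction, which is 0 for kSum and the large negative constant kReduceFloatInfNeg for kMax.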