From 95c21498d4f6e640e218f4b00349020f4ae7c69a Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Thu, 7 Mar 2024 16:57:49 +0800 Subject: [PATCH 1/9] add silu_and_mul for infer --- extensions/csrc/cuda/activation_kernel.cu | 65 +++++++++++++++++++ .../cuda/colossal_inference_C_frontend.cpp | 3 + extensions/csrc/cuda/include/mp_type_traits.h | 35 ++++++++++ extensions/csrc/cuda/type_shim.h | 3 + extensions/inference/inference_ops_cuda.py | 1 + .../test_ops/cuda/test_silu_and_mul.py | 33 ++++++++++ 6 files changed, 140 insertions(+) create mode 100644 extensions/csrc/cuda/activation_kernel.cu create mode 100644 extensions/csrc/cuda/include/mp_type_traits.h create mode 100644 tests/test_infer/test_ops/cuda/test_silu_and_mul.py diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu new file mode 100644 index 000000000..4121b67fc --- /dev/null +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -0,0 +1,65 @@ +#include +#include +#include + +#include "type_shim.h" +#include "include/mp_type_traits.h" + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + using MT = typename infer::dtype::MPTypeTrait::Type; + return static_cast((static_cast(x)) / (static_cast(1.0f) + expf(static_cast(-x)))); +} + +template +__global__ void act_and_mul_kernel( + const scalar_t* __restrict__ ins_data, + scalar_t* __restrict__ outs_data, + const int64_t numel) { + using MT = typename infer::dtype::MPTypeTrait::Type; + + int64_t idx = static_cast(threadIdx.x) + static_cast(blockIdx.x) * static_cast(blockDim.x); + const int64_t grid_size = blockDim.x * gridDim.x; + if(idx > numel) { + return; + } + + for(int64_t i = idx; i < numel; i += grid_size) { + scalar_t x = ins_data[i]; + scalar_t y = ins_data[i+numel]; + outs_data[i] = static_cast(static_cast(ACT_FN(x)) * static_cast(y)); + } +} + +// Note(LiuYang):This func is designed for calculation mode like +// silu(x[:half_1stdim]) * (x[half_1stdim:]) +torch::Tensor silu_and_mul(const torch::Tensor& ins) +{ + auto ins_shape = ins.sizes().vec(); + + ins_shape[0] = ins_shape[0]/2; + auto outs = torch::zeros(ins_shape,ins.options()); + auto outs_shape = ins.sizes().vec(); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // Note(Liuyang): numel of ins must be divisible by 2 + int64_t numel = ((torch::numel(ins)) >> 1); + + // TODO(LiuYang): Maybe we need to implement a function to get launch config + dim3 grid((numel+255)/256); + dim3 block(256); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + ins.scalar_type(), + "silu_and_mul", + act_and_mul_kernel><<>>( + ins.data_ptr(), + outs.data_ptr(), + numel + );) + + AT_CUDA_CHECK(cudaGetLastError()); + return outs; +} diff --git a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp index ae410c14f..cc53d8b88 100644 --- a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp +++ b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp @@ -9,7 +9,10 @@ void decode_kv_cache_memcpy( torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& block_tables); // [batch_size, max_seq_len] +torch::Tensor silu_and_mul(const torch::Tensor& ins); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); } diff --git a/extensions/csrc/cuda/include/mp_type_traits.h b/extensions/csrc/cuda/include/mp_type_traits.h new file 
mode 100644 index 000000000..6b3ae9c1b --- /dev/null +++ b/extensions/csrc/cuda/include/mp_type_traits.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +#include "../type_shim.h" + +namespace infer { +namespace dtype { + +template +class MPTypeTrait { + public: + using Type = float; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +} // namespace dtype +} // namespace infer diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/cuda/type_shim.h index 511631935..7be3fab1b 100644 --- a/extensions/csrc/cuda/type_shim.h +++ b/extensions/csrc/cuda/type_shim.h @@ -4,6 +4,9 @@ This file is adapted from fused adam in NVIDIA/apex, commit a109f85 Licensed under the MIT License. */ + +#pragma once + #include #include "compat.h" diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index 12bec6fab..2858d7160 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension): for fname in [ "cuda/colossal_inference_C_frontend.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", + "cuda/activation_kernel.cu", ] ] return ret diff --git a/tests/test_infer/test_ops/cuda/test_silu_and_mul.py b/tests/test_infer/test_ops/cuda/test_silu_and_mul.py new file mode 100644 index 000000000..ced2db7ca --- /dev/null +++ b/tests/test_infer/test_ops/cuda/test_silu_and_mul.py @@ -0,0 +1,33 @@ +import pytest +import torch + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device + +inference_ops = InferenceOpsLoader().load() + + +@pytest.mark.parametrize("SHAPE_X", [2]) +@pytest.mark.parametrize("SHAPE_Y", [64]) +@pytest.mark.parametrize("SHAPE_Z", [11008]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype): + torch.manual_seed(5) + device = get_current_device() + ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device) + origin_input = ref_input.clone() + + act_out = torch.nn.functional.silu(ref_input[0], inplace=True) + ref_out = act_out * ref_input[1] + + origin_out = inference_ops.silu_and_mul(origin_input) + + if dtype == torch.float32: + assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5) + else: + assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3) + + +if __name__ == "__main__": + test_silu_and_mul(2, 64, 11008, torch.float32) + test_silu_and_mul(2, 64, 11008, torch.float16) From a46598ac5984c7dc5804d0cf8621698f1a6a8720 Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Fri, 8 Mar 2024 14:53:29 +0800 Subject: [PATCH 2/9] add reusable utils for cuda --- extensions/csrc/common/dev_info_mgr.h | 20 +++ extensions/csrc/common/target.h | 134 ++++++++++++++++++ .../csrc/cuda/utils/gpu_launch_config.h | 36 +++++ extensions/csrc/cuda/utils/micros.h | 12 ++ extensions/csrc/cuda/utils/nvgpu_dev_info.cc | 45 ++++++ extensions/csrc/cuda/utils/nvgpu_dev_info.h | 37 +++++ 6 files changed, 284 insertions(+) create mode 100644 extensions/csrc/common/dev_info_mgr.h create mode 100644 extensions/csrc/common/target.h create mode 100644 extensions/csrc/cuda/utils/gpu_launch_config.h create mode 100644 extensions/csrc/cuda/utils/micros.h create mode 100644 extensions/csrc/cuda/utils/nvgpu_dev_info.cc create mode 100644 
extensions/csrc/cuda/utils/nvgpu_dev_info.h diff --git a/extensions/csrc/common/dev_info_mgr.h b/extensions/csrc/common/dev_info_mgr.h new file mode 100644 index 000000000..7570666ad --- /dev/null +++ b/extensions/csrc/common/dev_info_mgr.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "common/nvgpu_dev_info.h" +#include "target.h" + +namespace colossalAI { +namespace common { + +template +class DevInfoMgr final { + public: + static std::unique_ptr GetDevInfo(int device_num) const { + return std::make_unique(device_num); + } +}; + +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/common/target.h b/extensions/csrc/common/target.h new file mode 100644 index 000000000..1c8a508e3 --- /dev/null +++ b/extensions/csrc/common/target.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include + +namespace colossalAI { +namespace common { + +class Target { + public: + enum class OS : int { + Unk = -1, + Linux, + Windows, + }; + enum class Arch : int { + Unk = -1, + X86, + Arm, + NVGPU, + AMDGPU, + Ascend, + }; + enum class BitLen : int { + Unk = -1, + k32, + k64, + }; + + explicit Target(OS os, Arch arch, BitLen bitlen) + : os_(os), arch_(arch), bitlen_(bitlen) {} + + bool defined() const { + return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk); + } + + std::string str() const { + std::string s{"OS: "}; + switch (os_) { + case OS::Unk: + s += "Unk"; + break; + case OS::Linux: + s += "Linux"; + break; + case OS::Windows: + s += "Windows"; + break; + default: + throw std::invalid_argument("Invalid OS type!"); + } + s += "\t"; + s += "Arch: "; + + switch (arch_) { + case Arch::Unk: + s += "Unk"; + break; + case Arch::X86: + s += "X86"; + break; + case Arch::Arm: + s += "Arm"; + break; + case Arch::NVGPU: + s += "NVGPU"; + break; + case Arch::AMDGPU: + s += "AMDGPU"; + break; + case Arch::Ascend: + s += "Ascend"; + break; + default: + throw std::invalid_argument("Invalid Arch type!"); + } + s += "\t"; + s += "BitLen: "; + + switch (bitlen_) { + case BitLen::Unk: + s += "Unk"; + break; + case BitLen::k32: + s += "k32"; + break; + case BitLen::k64: + s += "k64"; + break; + default: + throw std::invalid_argument("Invalid target bit length!"); + } + + return s; + } + + OS os() const { return os_; } + Arch arch() const { return arch_; } + BitLen bitlen() const { return bitlen_; } + + static Target DefaultX86Target(); + static Target DefaultArmTarget(); + static Target DefaultRocmTarget(); + static Target DefaultAscendTarget(); + + static Target DefaultCUDATarget() { + return Target(OS::Linux, Arch::CUDA, BitLen::k64); + } + + friend std::ostream& operator<<(std::ostream& os, const Target& target); + friend bool operator==(const Target& lhs, const Target& rhs); + friend bool operator!=(const Target& lhs, const Target& rhs); + + private: + OS os_{OS::Unk}; + Arch arch_{Arch::Unk}; + BitLen bitlen_{BitLen::Unk}; +}; + +std::ostream& operator<<(std::ostream& os, const Target& target) { + std::cout << target.str() << std::endl; +} +bool operator==(const Target& lhs, const Target& rhs) { + return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) && + (lhs.bitlen_ == rhs.bitlen_); +} +bool operator!=(const Target& lhs, const Target& rhs) { + return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) && + (lhs.bitlen_ != rhs.bitlen_); +} + +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/cuda/utils/gpu_launch_config.h b/extensions/csrc/cuda/utils/gpu_launch_config.h new file mode 100644 index 000000000..c7481323a --- 
/dev/null +++ b/extensions/csrc/cuda/utils/gpu_launch_config.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +namespace colossalAI { +namespace cuda { +namespace utils { + +GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size); + +// TODO(LiuYang): to be implemented +GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size); + +// TODO(LiuYang): to be implemented +GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size); + +class GPULaunchConfig { + public: + GPULaunchConfig(){}; + GPULaunchConfig(const dim3& block, const dim3& grid) + : block_(block), grid_(grid) {} + friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size); + + protected: + void set_block(const dim3& dim) { block_ = dim; } + void set_grid(const dim3& dim) { grid_ = dim; } + + private: + dim3 block_(1, 1, 1); + dim3 grid_(1, 1, 1); +} + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/utils/micros.h b/extensions/csrc/cuda/utils/micros.h new file mode 100644 index 000000000..9b410e3d8 --- /dev/null +++ b/extensions/csrc/cuda/utils/micros.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +#define CUDA_CHECK(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \ + } \ + } diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.cc b/extensions/csrc/cuda/utils/nvgpu_dev_info.cc new file mode 100644 index 000000000..e52abebff --- /dev/null +++ b/extensions/csrc/cuda/utils/nvgpu_dev_info.cc @@ -0,0 +1,45 @@ +#include "nvgpu_dev_info.h" + +#include + +namespace colossalAI { +namespace cuda { +namespace utils { + +std::array NVGPUDevInfo::GetMaxGridDims() const { + std::array ret; + ret[0] = prop_->maxGridSize[0]; + ret[1] = prop_->maxGridSize[1]; + ret[2] = prop_->maxGridSize[2]; + return ret; +} + +std::array NVGPUDevInfo::GetMaxBlockDims() const { + std::array ret; + ret[0] = prop_->maxThreadsDim[0]; + ret[1] = prop_->maxThreadsDim[1]; + ret[2] = prop_->maxThreadsDim[2]; + return ret; +} + +std::array NVGPUDevInfo::GetCapability() const { + std::array ret; + ret[0] = prop_.major; + ret[1] = prop_.minor; +} + +int NVGPUDevInfo::GetMultiProcessorCount() const { + return prop_->multiProcessorCount; +} + +int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const { + return prop_->maxThreadsPerMultiProcessor; +} + +int NVGPUDevInfo::GetMaxThreadsPerBlock() const { + return prop_->maxThreadsPerBlock; +} + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.h b/extensions/csrc/cuda/utils/nvgpu_dev_info.h new file mode 100644 index 000000000..c8c67c908 --- /dev/null +++ b/extensions/csrc/cuda/utils/nvgpu_dev_info.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "micros.h" +#include "target.h" + +namespace colossalAI { +namespace cuda { +namespace utils { + +class NVGPUDevInfo { + public: + explicit NVGPUDevInfo(int device_num) : device_num_(device_num) { + CUDA_CALL(cudaGetDeviceProperties(prop_, device)); + } + + std::array GetMaxGridDims() const; + std::array GetMaxBlockDims() const; + std::array GetCapability() const; + int GetMultiProcessorCount() const; + int GetMaxThreadsPerMultiProcessor() const; + int GetMaxThreadsPerBlock() const; + + private: + int device_num_; + cudaDeviceProp* prop_; +}; + +} // namespace utils +} // namespace cuda +} // namespace colossalAI From 
5eb5ff1464311ac16c29307d03a3c076aced7e03 Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Fri, 8 Mar 2024 15:41:14 +0800 Subject: [PATCH 3/9] refactor code --- .../{cuda/type_shim.h => common/micros.h} | 97 ++----------------- .../{cuda/include => common}/mp_type_traits.h | 10 +- extensions/csrc/cuda/activation_kernel.cu | 8 +- extensions/csrc/cuda/compat.h | 10 -- .../cuda/decode_kv_cache_memcpy_kernel.cu | 2 +- extensions/csrc/cuda/include/block_reduce.h | 87 +++++++++++++++++ extensions/csrc/cuda/layer_norm_cuda.cpp | 2 +- .../csrc/cuda/layer_norm_cuda_kernel.cu | 2 +- extensions/csrc/cuda/multi_tensor_adam.cu | 2 +- extensions/csrc/cuda/multi_tensor_apply.cuh | 2 +- .../csrc/cuda/multi_tensor_l2norm_kernel.cu | 3 +- extensions/csrc/cuda/multi_tensor_lamb.cu | 2 +- .../csrc/cuda/multi_tensor_scale_kernel.cu | 2 +- .../csrc/cuda/multi_tensor_sgd_kernel.cu | 2 +- .../csrc/cuda/scaled_masked_softmax_cuda.cu | 2 +- ...scaled_upper_triang_masked_softmax_cuda.cu | 2 +- 16 files changed, 117 insertions(+), 118 deletions(-) rename extensions/csrc/{cuda/type_shim.h => common/micros.h} (87%) rename extensions/csrc/{cuda/include => common}/mp_type_traits.h (75%) diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/common/micros.h similarity index 87% rename from extensions/csrc/cuda/type_shim.h rename to extensions/csrc/common/micros.h index 7be3fab1b..c2241029f 100644 --- a/extensions/csrc/cuda/type_shim.h +++ b/extensions/csrc/common/micros.h @@ -9,7 +9,15 @@ #include -#include "compat.h" +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) { \ @@ -214,90 +222,3 @@ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ "'"); \ } - -template -__device__ __forceinline__ T reduce_block_into_lanes( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = x[tid] + x[tid + i]; - __syncthreads(); - } - - T final; - - if (tid < 32) { - if (blockSize >= 64) - final = x[tid] + x[tid + 32]; - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = final + __shfl_down_sync(0xffffffff, final, i); - } - - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} - -template -__device__ __forceinline__ T reduce_block_into_lanes_max_op( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. 
- - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); - __syncthreads(); - } - - T final; - - if (tid < 32) { - if (blockSize >= 64) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = - fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); - } - - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} diff --git a/extensions/csrc/cuda/include/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h similarity index 75% rename from extensions/csrc/cuda/include/mp_type_traits.h rename to extensions/csrc/common/mp_type_traits.h index 6b3ae9c1b..8ede2d448 100644 --- a/extensions/csrc/cuda/include/mp_type_traits.h +++ b/extensions/csrc/common/mp_type_traits.h @@ -2,10 +2,10 @@ #include -#include "../type_shim.h" +#include "micros.h" -namespace infer { -namespace dtype { +namespace colossalAI { +namespace common { template class MPTypeTrait { @@ -31,5 +31,5 @@ class MPTypeTrait { using Type = float; }; -} // namespace dtype -} // namespace infer +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index 4121b67fc..5213a2313 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -2,13 +2,13 @@ #include #include -#include "type_shim.h" -#include "include/mp_type_traits.h" +#include "../common/micros.h" +#include "../common/mp_type_traits.h" template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - using MT = typename infer::dtype::MPTypeTrait::Type; + using MT = typename colossalAI::common::MPTypeTrait::Type; return static_cast((static_cast(x)) / (static_cast(1.0f) + expf(static_cast(-x)))); } @@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel( const scalar_t* __restrict__ ins_data, scalar_t* __restrict__ outs_data, const int64_t numel) { - using MT = typename infer::dtype::MPTypeTrait::Type; + using MT = typename colossalAI::common::MPTypeTrait::Type; int64_t idx = static_cast(threadIdx.x) + static_cast(blockIdx.x) * static_cast(blockDim.x); const int64_t grid_size = blockDim.x * gridDim.x; diff --git a/extensions/csrc/cuda/compat.h b/extensions/csrc/cuda/compat.h index a62beef91..e69de29bb 100644 --- a/extensions/csrc/cuda/compat.h +++ b/extensions/csrc/cuda/compat.h @@ -1,10 +0,0 @@ -// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 86db90c8b..15e613e35 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -2,7 +2,7 @@ #include #include -#include "type_shim.h" +#include "../common/micros.h" template __global__ void decode_kv_cache_memcpy_kernel( diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h index 38103c173..86409136b 100644 --- a/extensions/csrc/cuda/include/block_reduce.h +++ 
b/extensions/csrc/cuda/include/block_reduce.h @@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce(float *pval) { } warpReduce(pval); } + +template +__device__ __forceinline__ T reduce_block_into_lanes( + T *x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op( + T *x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = + fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. 
+ __syncthreads(); + } + + return final; +} diff --git a/extensions/csrc/cuda/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp index 15a07bb0c..3439e5e71 100644 --- a/extensions/csrc/cuda/layer_norm_cuda.cpp +++ b/extensions/csrc/cuda/layer_norm_cuda.cpp @@ -7,7 +7,7 @@ #include #include -#include "compat.h" +#include "../common/micros.h" namespace { diff --git a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu index 72b84d6ca..17d5b10f4 100644 --- a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu +++ b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu @@ -9,7 +9,7 @@ #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/DeviceUtils.cuh" -#include "type_shim.h" +#include "../common/micros.h" template __device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { diff --git a/extensions/csrc/cuda/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu index 9cc3ae1ea..b7793b364 100644 --- a/extensions/csrc/cuda/multi_tensor_adam.cu +++ b/extensions/csrc/cuda/multi_tensor_adam.cu @@ -15,7 +15,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh index ec55dd320..01a858661 100644 --- a/extensions/csrc/cuda/multi_tensor_apply.cuh +++ b/extensions/csrc/cuda/multi_tensor_apply.cuh @@ -12,7 +12,7 @@ #include #include -#include "compat.h" +#include "../common/micros.h" // #include diff --git a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu index 85f935152..57a79f7a8 100644 --- a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu @@ -11,7 +11,8 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" +#include "include/block_reduce.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu index 63771cf40..50dfc56bc 100644 --- a/extensions/csrc/cuda/multi_tensor_lamb.cu +++ b/extensions/csrc/cuda/multi_tensor_lamb.cu @@ -10,7 +10,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu index 2f58a0f16..0dec1d5d1 100644 --- a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu @@ -10,7 +10,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu index 7f48dbd5d..d0cf786f8 100644 --- a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu @@ -7,7 +7,7 @@ #include #include -#include "compat.h" +#include "../common/micros.h" #include "multi_tensor_apply.cuh" #define BLOCK_SIZE 512 diff --git a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu index 41781ebc7..2f968d30f 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu +++ b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu @@ -10,7 +10,7 @@ #include #include 
"scaled_masked_softmax.h" -#include "type_shim.h" +#include "../common/micros.h" namespace multihead_attn { namespace fused_softmax { diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu index 62c56e6f7..d9550dc2c 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu @@ -10,7 +10,7 @@ #include #include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" +#include "../common/micros.h" namespace multihead_attn { namespace fused_softmax { From f7aecc0c6bac001d10c1dd00274e0152e4c86df6 Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:21:12 +0800 Subject: [PATCH 4/9] feat rmsnorm cuda kernel and add unittest, benchmark script (#5417) --- .../modeling/models/nopadding_llama.py | 28 +++- .../modeling/policy/nopadding_llama.py | 35 +---- ...rmsnorm_triton.py => benchmark_rmsnorm.py} | 19 ++- .../cuda/colossal_inference_C_frontend.cpp | 17 +++ extensions/csrc/cuda/rms_layernorm_kernel.cu | 126 ++++++++++++++++++ extensions/inference/inference_ops_cuda.py | 3 +- tests/test_infer/test_inference_engine.py | 14 +- .../test_ops/cuda/test_rms_layernorm.py | 51 +++++++ 8 files changed, 244 insertions(+), 49 deletions(-) rename examples/inference/benchmark_ops/{benchmark_rmsnorm_triton.py => benchmark_rmsnorm.py} (79%) create mode 100644 extensions/csrc/cuda/rms_layernorm_kernel.cu create mode 100644 tests/test_infer/test_ops/cuda/test_rms_layernorm.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 876fed456..f84abab4b 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import ( LlamaForCausalLM, LlamaMLP, LlamaModel, + LlamaRMSNorm, ) from colossalai.inference.batch_bucket import BatchBucket @@ -19,6 +20,7 @@ from colossalai.kernel.triton import ( decoding_fused_rotary_embedding, flash_decoding_attention, get_xine_cache, + rms_layernorm, rotary_embedding, ) from colossalai.logging import get_dist_logger @@ -124,7 +126,7 @@ def llama_model_forward( hidden_states = hidden_states[last_token_indexs - 1].contiguous() residual = residual[last_token_indexs - 1].contiguous() norm_output = torch.empty_like(hidden_states) - hidden_states, _ = self.norm(hidden_states, norm_output, residual) + hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel) return hidden_states @@ -167,7 +169,7 @@ def llama_decoder_layer_forward( use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. 
""" - hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual) + hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) # Self Attention hidden_states = self.self_attn( hidden_states=hidden_states, @@ -185,12 +187,32 @@ def llama_decoder_layer_forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) hidden_states = self.mlp(hidden_states) return hidden_states, residual +def llama_rmsnorm_forward( + self: LlamaRMSNorm, + hidden_states: torch.Tensor, + norm_output: torch.Tensor, + residual: torch.Tensor = None, + use_cuda_kernel: bool = True, +): + if use_cuda_kernel: + if residual is not None: + inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon) + return hidden_states, residual + + if norm_output is None: + norm_output = torch.empty_like(hidden_states) + inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon) + return norm_output, hidden_states + else: + return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) + + class NopadLlamaAttention(LlamaAttention): def __init__( self, diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py index 13695b835..bb9a22b41 100644 --- a/colossalai/inference/modeling/policy/nopadding_llama.py +++ b/colossalai/inference/modeling/policy/nopadding_llama.py @@ -1,6 +1,5 @@ from functools import partial -import torch from torch.nn import Parameter from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm @@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import ( llama_causal_lm_forward, llama_decoder_layer_forward, llama_model_forward, + llama_rmsnorm_forward, ) from colossalai.inference.utils import init_to_get_rotary from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription @@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, # import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -try: - from colossalai.kernel.triton import rms_layernorm - - HAS_TRITON_RMSNORM = True -except: - print("you should install triton from https://github.com/openai/triton") - HAS_TRITON_RMSNORM = False - - -def get_triton_rmsnorm_forward(): - if HAS_TRITON_RMSNORM: - - def _triton_rmsnorm_forward( - self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None - ): - return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual) - - return _triton_rmsnorm_forward - else: - return None - class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): def __init__(self) -> None: @@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy): description=method_replacement, policy=policy, target_key=LlamaDecoderLayer ) - infer_forward = None - if HAS_TRITON_RMSNORM: - infer_forward = get_triton_rmsnorm_forward() - - if infer_forward is not None: - method_replacement = {"forward": partial(infer_forward)} - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=LlamaRMSNorm - ) 
+ infer_forward = llama_rmsnorm_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm) return policy diff --git a/examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py b/examples/inference/benchmark_ops/benchmark_rmsnorm.py similarity index 79% rename from examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py rename to examples/inference/benchmark_ops/benchmark_rmsnorm.py index 9c60601b9..3b5166af0 100644 --- a/examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py +++ b/examples/inference/benchmark_ops/benchmark_rmsnorm.py @@ -1,14 +1,14 @@ import torch -import triton +from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import rms_layernorm try: import triton # noqa - except ImportError: print("please install triton from https://github.com/openai/triton") +inference_ops = InferenceOpsLoader().load() # Triton benchmark plot attributions configs = [ @@ -19,16 +19,20 @@ configs = [ line_vals=[ "vllm_rms_layernorm", "triton_rms_layernorm", - "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm", "vllm_rms_layernorm_with_residual", + "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm_with_residual", ], line_names=[ "vllm_rms_layernorm", "triton_rms_layernorm", - "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm", "vllm_rms_layernorm_with_residual", + "triton_rms_layernorm_with_residual", + "cuda_rms_layernorm_with_residual", ], - styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")], ylabel="ms", plot_name=f"RMSNorm benchmarking results", args={"HIDDEN_SIZE": 1024}, @@ -62,10 +66,15 @@ def benchmark_rms_layernorm( fn = lambda: vllm_norm(x) elif provider == "triton_rms_layernorm": fn = lambda: rms_layernorm(x, weight, eps=eps) + elif provider == "cuda_rms_layernorm": + out = torch.empty_like(x) + fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps) elif provider == "vllm_rms_layernorm_with_residual": fn = lambda: vllm_norm(x, residual=residual) elif provider == "triton_rms_layernorm_with_residual": fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual) + elif provider == "cuda_rms_layernorm_with_residual": + fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps) else: raise ValueError("Undefined provider.") diff --git a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp index cc53d8b88..73ed49e6c 100644 --- a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp +++ b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp @@ -11,8 +11,25 @@ void decode_kv_cache_memcpy( torch::Tensor silu_and_mul(const torch::Tensor& ins); +void rms_layernorm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + +void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); + + m.def("rms_layernorm", &rms_layernorm, + "Apply Root Mean 
Square (RMS) Normalization to the input tensor."); + + m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm, + "In-place fused Add and RMS Normalization."); } diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu new file mode 100644 index 000000000..99d36575d --- /dev/null +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -0,0 +1,126 @@ +/*This code from VLLM: + * https://github.com/vllm-project/vllm/ + * with minor changes. */ + +#include +#include +#include +#include + + +#include "block_reduce.h" +#include "type_shim.h" + +template +__global__ void rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + /* + * since the open-sourced LLM's hidden dimensions mainly range from + * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported + * hidden dimension limit to 8192, and each thread's capacity + * for caching input tensors to 8 (8192 = 8 * 1024) which + * will cause problems for extremely large models, such as + * Megatron-Turing NLG 530B with hidden dimensions up to 20480 + */ + float x_local[8]; + + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx]; + variance += x_local[cnt] * x_local[cnt]; + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + } +} + +template +__global__ void fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + float x_local[8]; + + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx]; + x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx]; + variance += x_local[cnt] * x_local[cnt]; + residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt]; + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + } +} + +void rms_layernorm( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + 
weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) +} + +void fused_add_rms_layernorm( + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) +} diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index 2858d7160..042c598fb 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -13,12 +13,13 @@ class InferenceOpsCudaExtension(_CudaExtension): "cuda/colossal_inference_C_frontend.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", "cuda/activation_kernel.cu", + "cuda/rms_layernorm_kernel.cu", ] ] return ret def include_dirs(self): - ret = [self.get_cuda_home_include()] + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] return ret def cxx_flags(self): diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index edd92bb96..25b2c2f43 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -22,15 +22,11 @@ def setup_seed(seed): def check_inference_engine(use_engine=False, prompt_template=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - model = ( - LlamaForCausalLM( - LlamaConfig( - vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 - ) + model = LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 ) - .cuda() - .half() - ) + ).cuda() model = model.eval() inputs = [ @@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None): top_k = 50 if use_engine: - inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template) + inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32") inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) diff --git a/tests/test_infer/test_ops/cuda/test_rms_layernorm.py b/tests/test_infer/test_ops/cuda/test_rms_layernorm.py new file mode 100644 index 000000000..d14010600 --- /dev/null +++ b/tests/test_infer/test_ops/cuda/test_rms_layernorm.py @@ -0,0 +1,51 @@ +import pytest +import torch +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.utils import get_current_device + +inference_ops = InferenceOpsLoader().load() + + +@pytest.mark.parametrize("M", [2, 4, 8, 16]) +@pytest.mark.parametrize("N", [64, 128, 512]) +def test_rms_layernorm(M: int, N: int): + torch.manual_seed(123) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + + device = get_current_device() + + 
dtype = torch.float16 + eps = 1e-5 + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.ones(w_shape, dtype=dtype, device=device) + residual = torch.rand(x_shape, dtype=dtype, device=device) + residual_copy = residual.clone() + rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda() + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + x_copy = x.clone() + + y_cuda = torch.empty_like(x) + inference_ops.rms_layernorm(y_cuda, x, weight, eps) + y_llama = rms_norm.forward(x).to(dtype) + + assert y_cuda.shape == y_llama.shape + assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3) + + inference_ops.fused_add_rms_layernorm(x, residual, weight, eps) + y_cuda = x + + x = x_copy + residual_copy + y_llama = rms_norm.forward(x).to(dtype) + + assert y_cuda.shape == y_llama.shape + assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3) + assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3) + + +if __name__ == "__main__": + test_rms_layernorm(16, 512) From 095c070a6eefe1a76fe3483b21986826114d6d17 Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Mon, 11 Mar 2024 17:06:57 +0800 Subject: [PATCH 5/9] refactor code --- extensions/cpu_adam/cpu_adam_x86.py | 2 +- extensions/csrc/cuda/compat.h | 0 .../{layer_norm_cuda_kernel.cu => layer_norm_kernel.cu} | 0 extensions/csrc/cuda/{moe_cuda_kernel.cu => moe_kernel.cu} | 0 .../{multi_tensor_adam.cu => multi_tensor_adam_kernel.cu} | 0 .../{multi_tensor_lamb.cu => multi_tensor_lamb_kernel.cu} | 0 .../inference.cpp} | 0 .../cuda/{layer_norm_cuda.cpp => pybind/layer_norm.cpp} | 0 extensions/csrc/cuda/{moe_cuda.cpp => pybind/moe.cpp} | 0 .../cuda/{colossal_C_frontend.cpp => pybind/optimizer.cpp} | 0 extensions/csrc/cuda/{ => pybind}/scaled_masked_softmax.cpp | 0 .../{ => pybind}/scaled_upper_triang_masked_softmax.cpp | 0 extensions/csrc/cuda/rms_layernorm_kernel.cu | 2 +- ...sked_softmax_cuda.cu => scaled_masked_softmax_kernel.cu} | 0 ...cuda.cu => scaled_upper_triang_masked_softmax_kernel.cu} | 0 extensions/csrc/{cuda => x86}/cpu_adam.cpp | 0 extensions/csrc/{cuda => x86}/cpu_adam.h | 0 extensions/inference/inference_ops_cuda.py | 2 +- extensions/layernorm/layernorm_cuda.py | 2 +- extensions/moe/moe_cuda.py | 2 +- extensions/optimizer/fused_optimizer_cuda.py | 6 +++--- extensions/softmax/scaled_masked_softmax_cuda.py | 2 +- .../softmax/scaled_upper_triangle_masked_softmax_cuda.py | 4 ++-- 23 files changed, 11 insertions(+), 11 deletions(-) delete mode 100644 extensions/csrc/cuda/compat.h rename extensions/csrc/cuda/{layer_norm_cuda_kernel.cu => layer_norm_kernel.cu} (100%) rename extensions/csrc/cuda/{moe_cuda_kernel.cu => moe_kernel.cu} (100%) rename extensions/csrc/cuda/{multi_tensor_adam.cu => multi_tensor_adam_kernel.cu} (100%) rename extensions/csrc/cuda/{multi_tensor_lamb.cu => multi_tensor_lamb_kernel.cu} (100%) rename extensions/csrc/cuda/{colossal_inference_C_frontend.cpp => pybind/inference.cpp} (100%) rename extensions/csrc/cuda/{layer_norm_cuda.cpp => pybind/layer_norm.cpp} (100%) rename extensions/csrc/cuda/{moe_cuda.cpp => pybind/moe.cpp} (100%) rename extensions/csrc/cuda/{colossal_C_frontend.cpp => pybind/optimizer.cpp} (100%) rename extensions/csrc/cuda/{ => pybind}/scaled_masked_softmax.cpp (100%) rename extensions/csrc/cuda/{ => pybind}/scaled_upper_triang_masked_softmax.cpp (100%) rename extensions/csrc/cuda/{scaled_masked_softmax_cuda.cu => scaled_masked_softmax_kernel.cu} (100%) rename extensions/csrc/cuda/{scaled_upper_triang_masked_softmax_cuda.cu => scaled_upper_triang_masked_softmax_kernel.cu} (100%) 
rename extensions/csrc/{cuda => x86}/cpu_adam.cpp (100%) rename extensions/csrc/{cuda => x86}/cpu_adam.h (100%) diff --git a/extensions/cpu_adam/cpu_adam_x86.py b/extensions/cpu_adam/cpu_adam_x86.py index a38194167..27b06bb65 100644 --- a/extensions/cpu_adam/cpu_adam_x86.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -21,7 +21,7 @@ class CpuAdamX86Extension(_CudaExtension): # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("cuda/cpu_adam.cpp"), + self.csrc_abs_path("x86/cpu_adam.cpp"), ] return ret diff --git a/extensions/csrc/cuda/compat.h b/extensions/csrc/cuda/compat.h deleted file mode 100644 index e69de29bb..000000000 diff --git a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_kernel.cu similarity index 100% rename from extensions/csrc/cuda/layer_norm_cuda_kernel.cu rename to extensions/csrc/cuda/layer_norm_kernel.cu diff --git a/extensions/csrc/cuda/moe_cuda_kernel.cu b/extensions/csrc/cuda/moe_kernel.cu similarity index 100% rename from extensions/csrc/cuda/moe_cuda_kernel.cu rename to extensions/csrc/cuda/moe_kernel.cu diff --git a/extensions/csrc/cuda/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam_kernel.cu similarity index 100% rename from extensions/csrc/cuda/multi_tensor_adam.cu rename to extensions/csrc/cuda/multi_tensor_adam_kernel.cu diff --git a/extensions/csrc/cuda/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb_kernel.cu similarity index 100% rename from extensions/csrc/cuda/multi_tensor_lamb.cu rename to extensions/csrc/cuda/multi_tensor_lamb_kernel.cu diff --git a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp b/extensions/csrc/cuda/pybind/inference.cpp similarity index 100% rename from extensions/csrc/cuda/colossal_inference_C_frontend.cpp rename to extensions/csrc/cuda/pybind/inference.cpp diff --git a/extensions/csrc/cuda/layer_norm_cuda.cpp b/extensions/csrc/cuda/pybind/layer_norm.cpp similarity index 100% rename from extensions/csrc/cuda/layer_norm_cuda.cpp rename to extensions/csrc/cuda/pybind/layer_norm.cpp diff --git a/extensions/csrc/cuda/moe_cuda.cpp b/extensions/csrc/cuda/pybind/moe.cpp similarity index 100% rename from extensions/csrc/cuda/moe_cuda.cpp rename to extensions/csrc/cuda/pybind/moe.cpp diff --git a/extensions/csrc/cuda/colossal_C_frontend.cpp b/extensions/csrc/cuda/pybind/optimizer.cpp similarity index 100% rename from extensions/csrc/cuda/colossal_C_frontend.cpp rename to extensions/csrc/cuda/pybind/optimizer.cpp diff --git a/extensions/csrc/cuda/scaled_masked_softmax.cpp b/extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp similarity index 100% rename from extensions/csrc/cuda/scaled_masked_softmax.cpp rename to extensions/csrc/cuda/pybind/scaled_masked_softmax.cpp diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp b/extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp similarity index 100% rename from extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp rename to extensions/csrc/cuda/pybind/scaled_upper_triang_masked_softmax.cpp diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 99d36575d..0ab40f9f7 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -9,7 +9,7 @@ #include "block_reduce.h" -#include "type_shim.h" +#include "../common/micros.h" template __global__ void rms_layernorm_kernel( diff --git a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu 
b/extensions/csrc/cuda/scaled_masked_softmax_kernel.cu similarity index 100% rename from extensions/csrc/cuda/scaled_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_masked_softmax_kernel.cu diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu similarity index 100% rename from extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax_kernel.cu diff --git a/extensions/csrc/cuda/cpu_adam.cpp b/extensions/csrc/x86/cpu_adam.cpp similarity index 100% rename from extensions/csrc/cuda/cpu_adam.cpp rename to extensions/csrc/x86/cpu_adam.cpp diff --git a/extensions/csrc/cuda/cpu_adam.h b/extensions/csrc/x86/cpu_adam.h similarity index 100% rename from extensions/csrc/cuda/cpu_adam.h rename to extensions/csrc/x86/cpu_adam.h diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index 042c598fb..f465fe600 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -10,7 +10,7 @@ class InferenceOpsCudaExtension(_CudaExtension): ret = [ self.csrc_abs_path(fname) for fname in [ - "cuda/colossal_inference_C_frontend.cpp", + "cuda/pybind/inference.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", "cuda/activation_kernel.cu", "cuda/rms_layernorm_kernel.cu", diff --git a/extensions/layernorm/layernorm_cuda.py b/extensions/layernorm/layernorm_cuda.py index db5f2fce1..36cf73590 100644 --- a/extensions/layernorm/layernorm_cuda.py +++ b/extensions/layernorm/layernorm_cuda.py @@ -7,7 +7,7 @@ class LayerNormCudaExtension(_CudaExtension): super().__init__(name="layernorm_cuda") def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/layer_norm.cpp", "cuda/layer_norm_kernel.cu"]] return ret def include_dirs(self): diff --git a/extensions/moe/moe_cuda.py b/extensions/moe/moe_cuda.py index 52883e97f..722daae33 100644 --- a/extensions/moe/moe_cuda.py +++ b/extensions/moe/moe_cuda.py @@ -11,7 +11,7 @@ class MoeCudaExtension(_CudaExtension): return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe.cpp", "cuda/moe_kernel.cu"]] return ret def cxx_flags(self): diff --git a/extensions/optimizer/fused_optimizer_cuda.py b/extensions/optimizer/fused_optimizer_cuda.py index e065cf34a..41c6260aa 100644 --- a/extensions/optimizer/fused_optimizer_cuda.py +++ b/extensions/optimizer/fused_optimizer_cuda.py @@ -10,12 +10,12 @@ class FusedOptimizerCudaExtension(_CudaExtension): ret = [ self.csrc_abs_path(fname) for fname in [ - "cuda/colossal_C_frontend.cpp", + "cuda/pybind/optimizer.cpp", "cuda/multi_tensor_sgd_kernel.cu", "cuda/multi_tensor_scale_kernel.cu", - "cuda/multi_tensor_adam.cu", + "cuda/multi_tensor_adam_kernel.cu", "cuda/multi_tensor_l2norm_kernel.cu", - "cuda/multi_tensor_lamb.cu", + "cuda/multi_tensor_lamb_kernel.cu", ] ] return ret diff --git a/extensions/softmax/scaled_masked_softmax_cuda.py b/extensions/softmax/scaled_masked_softmax_cuda.py index 5b4208dba..797638c3b 100644 --- a/extensions/softmax/scaled_masked_softmax_cuda.py +++ b/extensions/softmax/scaled_masked_softmax_cuda.py @@ -9,7 +9,7 @@ class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): def sources_files(self): ret 
= [ self.csrc_abs_path(fname) - for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"] + for fname in ["cuda/pybind/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_kernel.cu"] ] return ret diff --git a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py index d4f27a921..d48d542ad 100644 --- a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py +++ b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -13,8 +13,8 @@ class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): ret = [ self.csrc_abs_path(fname) for fname in [ - "cuda/scaled_upper_triang_masked_softmax.cpp", - "cuda/scaled_upper_triang_masked_softmax_cuda.cu", + "cuda/pybind/scaled_upper_triang_masked_softmax.cpp", + "cuda/scaled_upper_triang_masked_softmax_kernel.cu", ] ] return ret From b699f54007c52b2f4ec56326a495b06858cf8856 Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Tue, 12 Mar 2024 17:48:02 +0800 Subject: [PATCH 6/9] optimize rmsnorm: add vectorized elementwise op, feat loop unrolling (#5441) --- extensions/csrc/common/cuda_type_utils.h | 122 +++++++ extensions/csrc/cuda/rms_layernorm_kernel.cu | 322 ++++++++++++++++--- 2 files changed, 406 insertions(+), 38 deletions(-) create mode 100644 extensions/csrc/common/cuda_type_utils.h diff --git a/extensions/csrc/common/cuda_type_utils.h b/extensions/csrc/common/cuda_type_utils.h new file mode 100644 index 000000000..35d4c1492 --- /dev/null +++ b/extensions/csrc/common/cuda_type_utils.h @@ -0,0 +1,122 @@ +/* + * This code from NVIDIA FasterTransformer: + * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh + */ + +#pragma once + +#include +#include + +template +inline __device__ T add(T a, T b) { + return a + b; +} + +template <> +inline __device__ half2 add(half2 a, half2 b) { + return __hadd2(a, b); +} + +template <> +inline __device__ half add(half a, half b) { + return __hadd(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { + return bf16hadd2(a, b); +} + +template <> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { + return bf16hadd(a, b); +} + +#endif // ENABLE_BF16 + +template +inline __device__ T mul(T a, T b, T c) { + return a * b * c; +} + +template <> +inline __device__ half2 mul(half2 a, half2 b, half2 c) { + return __hmul2(__hmul2(a, b), c); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, + __nv_bfloat16 c) { + return bf16hmul(a, b, c); +} + +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +__device__ inline T_OUT cuda_cast(T_IN val) { + return val; +} + +template <> +__device__ inline float2 cuda_cast(int2 val) { + return make_float2(val.x, val.y); +} +template <> +__device__ inline float2 cuda_cast(float val) { + return make_float2(val, val); +} +template <> +__device__ inline float2 cuda_cast(half2 val) { + return __half22float2(val); +} +template <> +__device__ inline half2 cuda_cast(float2 val) { + return __float22half2_rn(val); +} +template <> +__device__ inline half2 cuda_cast(float val) { + return __float2half2_rn(val); +} +template <> +__device__ inline half2 cuda_cast(half val) { + return __half2half2(val); +} +template <> +__device__ inline 
float cuda_cast(half val) { + return __half2float(val); +} + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = at::Half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +#if ENABLE_BF16 +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = at::BFloat16; +}; + +template <> +struct TypeConverter { + using Type = __nv_bfloat162; +}; +#endif // ENABLE_BF16 diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 0ab40f9f7..0e3e4e900 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -1,5 +1,5 @@ /*This code from VLLM: - * https://github.com/vllm-project/vllm/ + * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu * with minor changes. */ #include @@ -10,8 +10,10 @@ #include "block_reduce.h" #include "../common/micros.h" +#include "../common/cuda_type_utils.h" -template +// optimized for half and bf16 +template __global__ void rms_layernorm_kernel( scalar_t* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] @@ -19,8 +21,9 @@ __global__ void rms_layernorm_kernel( const float epsilon, const int num_tokens, const int hidden_size) { + using scalar2_t = typename TypeConverter::Type; __shared__ float s_variance; - float variance = 0.0f; + /* * since the open-sourced LLM's hidden dimensions mainly range from * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported @@ -29,11 +32,22 @@ __global__ void rms_layernorm_kernel( * will cause problems for extremely large models, such as * Megatron-Turing NLG 530B with hidden dimensions up to 20480 */ - float x_local[8]; + scalar2_t x_local[4]; - for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { - x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx]; - variance += x_local[cnt] * x_local[cnt]; + scalar2_t* out_ptr = (scalar2_t*)out; + const scalar2_t* input_ptr = (scalar2_t*)input; + const scalar2_t* weight_ptr = (const scalar2_t*)weight; + + float variance = 0.0f; + int row_offset = blockIdx.x * hidden_size / 2; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input_ptr[id]; + float v1 = cuda_cast(x_local[cnt].x); + float v2 = cuda_cast(x_local[cnt].y); + variance += v1 * v1 + v2 * v2; } variance = blockReduceSum(variance); if (threadIdx.x == 0) { @@ -41,16 +55,19 @@ __global__ void rms_layernorm_kernel( } __syncthreads(); - for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + scalar2_t s_variance_2 = cuda_cast(s_variance); +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]); } } -template -__global__ void fused_add_rms_layernorm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] +template +__global__ void rms_layernorm_kernel( + float* __restrict__ out, // [..., hidden_size] + const 
float* __restrict__ input, // [..., hidden_size] + const float* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -58,11 +75,13 @@ __global__ void fused_add_rms_layernorm_kernel( float variance = 0.0f; float x_local[8]; + int row_offset = blockIdx.x * hidden_size; + +#pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { - x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx]; - x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx]; + int id = row_offset + idx; + x_local[cnt] = input[id]; variance += x_local[cnt] * x_local[cnt]; - residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt]; } variance = blockReduceSum(variance); if (threadIdx.x == 0) { @@ -70,8 +89,89 @@ __global__ void fused_add_rms_layernorm_kernel( } __syncthreads(); +#pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { - input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; + int id = row_offset + idx; + out[id] = ((x_local[cnt] * s_variance)) * weight[idx]; + } +} + +// optimized for half and bf16 +template +__global__ void fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + using scalar2_t = typename TypeConverter::Type; + __shared__ float s_variance; + scalar2_t x_local[4]; + + scalar2_t* input_ptr = (scalar2_t*)input; + scalar2_t* residual_ptr = (scalar2_t*)residual; + const scalar2_t* weight_ptr = (const scalar2_t*)weight; + + float variance = 0.0f; + int row_offset = blockIdx.x * hidden_size / 2; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input_ptr[id]; + x_local[cnt] = add(x_local[cnt], residual_ptr[id]); + float v1 = cuda_cast(x_local[cnt].x); + float v2 = cuda_cast(x_local[cnt].y); + variance += v1 * v1 + v2 * v2; + residual_ptr[id] = x_local[cnt]; + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + scalar2_t s_variance_2 = cuda_cast(s_variance); +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]); + } +} + +template +__global__ void fused_add_rms_layernorm_kernel( + float* __restrict__ input, // [..., hidden_size] + float* __restrict__ residual, // [..., hidden_size] + const float* __restrict__ weight, // [hidden_size] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + float x_local[8]; + + int row_offset = blockIdx.x * hidden_size; + +#pragma unroll unroll_factor + for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + x_local[cnt] = input[id]; + x_local[cnt] += residual[id]; + variance += x_local[cnt] * x_local[cnt]; + residual[id] = x_local[cnt]; + } + variance = blockReduceSum(variance); + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + +#pragma unroll unroll_factor + 
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { + int id = row_offset + idx; + input[id] = ((x_local[cnt] * s_variance)) * weight[idx]; } } @@ -88,16 +188,89 @@ void rms_layernorm( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_FLOAT_HALF_AND_BFLOAT( - input.scalar_type(), - "rms_layernorm_kernel", - rms_layernorm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + if (num_tokens >= 512) { + if (input.scalar_type() == at::ScalarType::Float) { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } else { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } + } else { + int unroll_factor = (hidden_size + block.x - 1) / block.x; + if (input.scalar_type() != at::ScalarType::Float) { + block.x = std::min(hidden_size / 2, 1024); + int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + } + switch (unroll_factor) { + case 1: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 2: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 4: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 8: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "rms_layernorm_kernel", + rms_layernorm_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + default: + AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); + } + } } void fused_add_rms_layernorm( @@ -113,14 +286,87 @@ void fused_add_rms_layernorm( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_FLOAT_HALF_AND_BFLOAT( - input.scalar_type(), - "fused_add_rms_layernorm_kernel", - fused_add_rms_layernorm_kernel<<>>( - input.data_ptr(), - residual.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size);) + if (num_tokens >= 512) { + if (input.scalar_type() == at::ScalarType::Float) { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } else { + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + } + } else { + int unroll_factor = (hidden_size + block.x - 1) / block.x; + if (input.scalar_type() != at::ScalarType::Float) { + block.x = std::min(hidden_size / 2, 1024); + int unroll_factor = 
(hidden_size / 2 + block.x - 1) / block.x; + } + switch (unroll_factor) { + case 1: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 2: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 4: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + case 8: + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), + "fused_add_rms_layernorm_kernel", + fused_add_rms_layernorm_kernel<<>>( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + epsilon, + num_tokens, + hidden_size);) + break; + default: + AT_ERROR("unroll_factor must be 1, 2, 4 or 8"); + } + } } From c1c45e9d8ecb6743e88e63dd151c617c0014e7c1 Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Wed, 13 Mar 2024 11:21:06 +0800 Subject: [PATCH 7/9] fix include path --- extensions/csrc/cuda/pybind/layer_norm.cpp | 2 +- extensions/moe/moe_cuda.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions/csrc/cuda/pybind/layer_norm.cpp b/extensions/csrc/cuda/pybind/layer_norm.cpp index 3439e5e71..b1f7c2543 100644 --- a/extensions/csrc/cuda/pybind/layer_norm.cpp +++ b/extensions/csrc/cuda/pybind/layer_norm.cpp @@ -7,7 +7,7 @@ #include #include -#include "../common/micros.h" +#include "../../common/micros.h" namespace { diff --git a/extensions/moe/moe_cuda.py b/extensions/moe/moe_cuda.py index 722daae33..7a4744d4d 100644 --- a/extensions/moe/moe_cuda.py +++ b/extensions/moe/moe_cuda.py @@ -11,7 +11,7 @@ class MoeCudaExtension(_CudaExtension): return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe.cpp", "cuda/moe_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/moe.cpp", "cuda/moe_kernel.cu"]] return ret def cxx_flags(self): From ed431de4e4f73584e6b9c11ab041ef54a8e83de6 Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Wed, 13 Mar 2024 16:00:55 +0800 Subject: [PATCH 8/9] fix rmsnorm template function invocation problem(template function partial specialization is not allowed in Cpp) and luckily pass e2e precision test (#5454) --- extensions/csrc/cuda/rms_layernorm_kernel.cu | 100 +++++++++++++------ tests/test_infer/test_inference_engine.py | 14 ++- 2 files changed, 79 insertions(+), 35 deletions(-) diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 0e3e4e900..8b250cb10 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -12,6 +12,34 @@ #include "../common/micros.h" #include "../common/cuda_type_utils.h" +#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) 
\ + if (DATA_SIZE == 2) { \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } else { \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + general_##__VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } \ + // optimized for half and bf16 template __global__ void rms_layernorm_kernel( @@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel( } } -template -__global__ void rms_layernorm_kernel( - float* __restrict__ out, // [..., hidden_size] - const float* __restrict__ input, // [..., hidden_size] - const float* __restrict__ weight, // [hidden_size] +template +__global__ void general_rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - x_local[cnt] = input[id]; + x_local[cnt] = (float) input[id]; variance += x_local[cnt] * x_local[cnt]; } variance = blockReduceSum(variance); @@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - out[id] = ((x_local[cnt] * s_variance)) * weight[idx]; + out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; } } @@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel( } } -template -__global__ void fused_add_rms_layernorm_kernel( - float* __restrict__ input, // [..., hidden_size] - float* __restrict__ residual, // [..., hidden_size] - const float* __restrict__ weight, // [hidden_size] +template +__global__ void general_fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - x_local[cnt] = input[id]; - x_local[cnt] += residual[id]; + x_local[cnt] = (float) input[id]; + x_local[cnt] += (float) residual[id]; variance += x_local[cnt] * x_local[cnt]; - residual[id] = x_local[cnt]; + residual[id] = (scalar_t) x_local[cnt]; } variance = blockReduceSum(variance); if (threadIdx.x == 0) { @@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - input[id] = ((x_local[cnt] * s_variance)) * weight[idx]; + input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; } } @@ -190,7 +218,8 @@ void rms_layernorm( if (num_tokens >= 512) { if (input.scalar_type() == at::ScalarType::Float) { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), 
input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -201,7 +230,8 @@ void rms_layernorm( num_tokens, hidden_size);) } else { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -216,11 +246,12 @@ void rms_layernorm( int unroll_factor = (hidden_size + block.x - 1) / block.x; if (input.scalar_type() != at::ScalarType::Float) { block.x = std::min(hidden_size / 2, 1024); - int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; } switch (unroll_factor) { case 1: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -232,7 +263,8 @@ void rms_layernorm( hidden_size);) break; case 2: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -244,7 +276,8 @@ void rms_layernorm( hidden_size);) break; case 4: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -256,7 +289,8 @@ void rms_layernorm( hidden_size);) break; case 8: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -288,7 +322,8 @@ void fused_add_rms_layernorm( if (num_tokens >= 512) { if (input.scalar_type() == at::ScalarType::Float) { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -299,7 +334,8 @@ void fused_add_rms_layernorm( num_tokens, hidden_size);) } else { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -314,11 +350,12 @@ void fused_add_rms_layernorm( int unroll_factor = (hidden_size + block.x - 1) / block.x; if (input.scalar_type() != at::ScalarType::Float) { block.x = std::min(hidden_size / 2, 1024); - int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; } switch (unroll_factor) { case 1: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -330,7 +367,8 @@ void fused_add_rms_layernorm( hidden_size);) break; case 2: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -342,7 +380,8 @@ void fused_add_rms_layernorm( hidden_size);) break; case 4: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -354,7 +393,8 @@ void fused_add_rms_layernorm( hidden_size);) break; case 8: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( diff --git 
a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 25b2c2f43..edd92bb96 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -22,11 +22,15 @@ def setup_seed(seed): def check_inference_engine(use_engine=False, prompt_template=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - model = LlamaForCausalLM( - LlamaConfig( - vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + model = ( + LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + ) ) - ).cuda() + .cuda() + .half() + ) model = model.eval() inputs = [ @@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None): top_k = 50 if use_engine: - inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32") + inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) From f366a5ea1f2626a7870acaf8866f21d5fb49c388 Mon Sep 17 00:00:00 2001 From: yuehuayingxueluo <867460659@qq.com> Date: Wed, 13 Mar 2024 17:20:03 +0800 Subject: [PATCH 9/9] [Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418) * add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the gloabl memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline --- .../modeling/models/nopadding_llama.py | 19 +- colossalai/inference/utils.py | 4 +- ... 
benchmark_fused_rotary_embdding_unpad.py} | 34 +- ...dding.py => benchmark_rotary_embedding.py} | 29 +- .../benchmark_ops/benchmark_xine_copy.py | 54 ++ extensions/csrc/common/vector_copy_utils.h | 98 ++++ extensions/csrc/cuda/activation_kernel.cu | 3 + .../cuda/decode_kv_cache_memcpy_kernel.cu | 163 ++++-- .../cuda/fused_rotary_emb_and_cache_kernel.cu | 472 ++++++++++++++++++ extensions/csrc/cuda/pybind/inference.cpp | 24 + extensions/inference/inference_ops_cuda.py | 1 + tests/test_infer/test_inference_engine.py | 14 +- .../cuda/test_rotary_embdding_unpad.py | 91 ++++ 13 files changed, 928 insertions(+), 78 deletions(-) rename examples/inference/benchmark_ops/{benchmark_rotary_embdding_unpad.py => benchmark_fused_rotary_embdding_unpad.py} (70%) rename examples/inference/benchmark_ops/{benchmark_fused_rotary_embedding.py => benchmark_rotary_embedding.py} (62%) create mode 100644 examples/inference/benchmark_ops/benchmark_xine_copy.py create mode 100644 extensions/csrc/common/vector_copy_utils.h create mode 100644 extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu create mode 100644 tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index f84abab4b..12de4802b 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -320,8 +320,12 @@ class NopadLlamaAttention(LlamaAttention): ) block_size = k_cache.size(-2) + if is_prompts: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + if use_cuda_kernel: + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + else: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) attn_output = context_attention_unpadded( q=query_states, k=key_states, @@ -337,9 +341,16 @@ class NopadLlamaAttention(LlamaAttention): ) else: if use_cuda_kernel: - rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - inference_ops.decode_kv_cache_memcpy( - key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables + inference_ops.rotary_embedding_and_cache_copy( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + sequence_lengths, + block_tables, ) else: decoding_fused_rotary_embedding( diff --git a/colossalai/inference/utils.py b/colossalai/inference/utils.py index 990864813..a97b9c9d6 100644 --- a/colossalai/inference/utils.py +++ b/colossalai/inference/utils.py @@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() - self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + self._cos_cached = torch.cos(freqs).to(self.dtype).cuda() + self._sin_cached = torch.sin(freqs).to(self.dtype).cuda() diff --git a/examples/inference/benchmark_ops/benchmark_rotary_embdding_unpad.py b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py similarity index 70% rename from examples/inference/benchmark_ops/benchmark_rotary_embdding_unpad.py rename to examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py index 0e22ed7d2..f11630dff 100644 --- a/examples/inference/benchmark_ops/benchmark_rotary_embdding_unpad.py +++ b/examples/inference/benchmark_ops/benchmark_fused_rotary_embdding_unpad.py 
@@ -1,8 +1,11 @@ import torch +from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token +inference_ops = InferenceOpsLoader().load() + try: import triton # noqa @@ -16,9 +19,19 @@ configs = [ x_names=["num_tokens"], x_vals=[2**i for i in range(4, 11)], line_arg="provider", - line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"], - line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"], - styles=[("red", "-"), ("blue", "-")], + line_vals=[ + "no_fused_triton_rotary_emb_func", + "fused_triton_rotary_emb_func", + "no_fused_cuda_rotary_emb_func", + "fused_cuda_rotary_emb_func", + ], + line_names=[ + "no_fused_triton_rotary_emb_func", + "fused_triton_rotary_emb_func", + "no_fused_cuda_rotary_emb_func", + "fused_cuda_rotary_emb_func", + ], + styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")], ylabel="ms", plot_name=f"rotary_emb-batch-{BATCH}", args={"num_kv_heads": 16}, @@ -32,7 +45,7 @@ def benchmark_rotary_emb( num_tokens: int, num_kv_heads: int, ): - BATCH_SIZE = 4 + BATCH_SIZE = 16 SEQ_LEN = num_tokens // BATCH_SIZE max_num_blocks_per_seq = 8 block_size = 64 @@ -68,7 +81,7 @@ def benchmark_rotary_emb( kv_seq_lengths = past_kv_seq_lengths + 1 block_tables = block_tables.to(device="cuda") - if provider == "no_fused_rotary_emb_func": + if provider == "no_fused_triton_rotary_emb_func": fn = lambda: [ rotary_embedding(new_q, new_k, cos, sin), copy_kv_to_blocked_cache( @@ -77,7 +90,16 @@ def benchmark_rotary_emb( ] elif provider == "fused_triton_rotary_emb_func": fn = lambda: decoding_fused_rotary_embedding( - new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths + new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths + ) + elif provider == "no_fused_cuda_rotary_emb_func": + fn = lambda: [ + inference_ops.rotary_embedding(new_q, new_k, cos, sin), + inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables), + ] + elif provider == "fused_cuda_rotary_emb_func": + fn = lambda: inference_ops.rotary_embedding_and_cache_copy( + new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables ) else: raise ValueError("Undefined provider") diff --git a/examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py b/examples/inference/benchmark_ops/benchmark_rotary_embedding.py similarity index 62% rename from examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py rename to examples/inference/benchmark_ops/benchmark_rotary_embedding.py index 9b44ef791..97cf2e0b2 100644 --- a/examples/inference/benchmark_ops/benchmark_fused_rotary_embedding.py +++ b/examples/inference/benchmark_ops/benchmark_rotary_embedding.py @@ -1,7 +1,11 @@ import torch import triton +from vllm._C import ops -from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding +from colossalai.kernel.kernel_loader import InferenceOpsLoader +from colossalai.kernel.triton import rotary_embedding + +inference_ops = InferenceOpsLoader().load() BATCH = 16 configs = [ @@ -9,9 +13,9 @@ configs = [ x_names=["num_tokens"], x_vals=[2**i for i in range(4, 12)], line_arg="provider", - line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"], - line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"], - 
styles=[("red", "-"), ("blue", "-")], + line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"], + line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"], + styles=[("red", "-"), ("blue", "-"), ("yellow", "-")], ylabel="ms", plot_name=f"rotary_emb-batch-{BATCH}", args={"num_kv_heads": 16}, @@ -48,12 +52,19 @@ def benchmark_rotary_emb( cos_shape = (4096, head_dim // 2) cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - lengths = torch.tensor([3, 4, 6, 7], device="cuda") - if provider == "torch_rotary_emb_func": - fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens]) - elif provider == "triton_rotary_emb_func": - fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths) + cos_sin = torch.stack((cos, sin), dim=1).contiguous() + + positions = torch.arange(num_tokens).cuda() + + if provider == "triton_func": + fn = lambda: rotary_embedding(q, k, cos, sin) + elif provider == "colossal_cuda_func": + fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin) + elif provider == "vllm_cuda_func": + q = q.view(num_tokens, -1) + k = k.view(num_tokens, -1) + fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True) else: raise ValueError("Undefined provider") diff --git a/examples/inference/benchmark_ops/benchmark_xine_copy.py b/examples/inference/benchmark_ops/benchmark_xine_copy.py new file mode 100644 index 000000000..b15232b91 --- /dev/null +++ b/examples/inference/benchmark_ops/benchmark_xine_copy.py @@ -0,0 +1,54 @@ +import torch + +from colossalai.kernel.triton import get_xine_cache +from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin + +try: + import triton # noqa + +except ImportError: + print("please install triton from https://github.com/openai/triton") + + +configs = [ + triton.testing.Benchmark( + x_names=["max_num_tokens"], + x_vals=[2**i for i in range(6, 12)], + line_arg="provider", + line_vals=["torch_get_cos_sin", "triton_get_cos_sin"], + line_names=["torch_get_cos_sin", "triton_get_cos_sin"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name="Get_cos-sin_func", + args={"batch_size": 16, "head_dim": 256}, + ) +] + + +@triton.testing.perf_report(configs) +def benchmark_get_xine_cache( + provider: str, + max_num_tokens: int, + batch_size: int, + head_dim: int, +): + warmup = 10 + rep = 1000 + dtype = torch.float16 + cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda") + lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda") + + if provider == "torch_get_cos_sin": + fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype) + elif provider == "triton_get_cos_sin": + fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True) + else: + raise ValueError("Undefined provider") + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + return ms + + +if __name__ == "__main__": + benchmark_get_xine_cache.run(save_path=".", print_data=True) diff --git a/extensions/csrc/common/vector_copy_utils.h b/extensions/csrc/common/vector_copy_utils.h new file mode 100644 index 000000000..456440cf6 --- /dev/null +++ b/extensions/csrc/common/vector_copy_utils.h @@ -0,0 +1,98 @@ + +#include +#include + +#include + +#include "string" + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void 
copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float *)dst) = *((float *)src); +} + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} + +template <> +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float4 *)dst) = *((float4 *)src); +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float *)dst) = *((float *)src); +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float4 *)dst) = *((float4 *)src); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + *((float4 *)dst) = *((float4 *)src); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + // Since the maximum memory alignment length is 128 bits, we choose float4 + // here. + *((float4 *)dst) = *((float4 *)src); + *((float4 *)(dst + 4)) = *((float4 *)(src + 4)); +} + +template +int get_vec_size(const torch::Tensor &tensor) { + uint64_t address = reinterpret_cast(tensor.data_ptr()); + const int max_aligned_size = 128; + const int dtype_size = sizeof(T) * 8; + + const int vec_size = max_aligned_size / sizeof(T) / 8; + + if (address % (dtype_size * 4) == 0) { + return std::min(4, vec_size); + } else if (address % (dtype_size * 2) == 0) { + return std::min(2, vec_size); + } else { + return 1; + } +} diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index 5213a2313..e9dc01753 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -39,6 +39,9 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins) auto ins_shape = ins.sizes().vec(); ins_shape[0] = ins_shape[0]/2; + if (ins_shape[0] == 1) { + ins_shape.erase(ins_shape.begin()); + } auto outs = torch::zeros(ins_shape,ins.options()); auto outs_shape = ins.sizes().vec(); diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 15e613e35..7eb44ecd0 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -1,10 +1,10 @@ #include #include -#include +#include "../common/vector_copy_utils.h" #include "../common/micros.h" -template +template __global__ void decode_kv_cache_memcpy_kernel( const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, @@ -12,79 +12,146 @@ __global__ void decode_kv_cache_memcpy_kernel( scalar_t* __restrict__ value_cache, const int* __restrict__ sequence_lengths, const int* __restrict__ block_tables, - const int num_heads, - const int head_size, + const int head_num, + const int head_dim, const int block_size, - const int key_stride, - const int value_stride, + const int64_t key_stride, + 
const int64_t value_stride, const int block_table_stride ) { const int seq_id = blockIdx.x; const int seq_len = sequence_lengths[seq_id] - 1; - const int seq_id_in_block_table = seq_len / block_size; const int block_offset = seq_len % block_size; - const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table]; - const int hidden_size = num_heads * head_size; + const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size]; + const int hidden_size = head_num * head_dim; if ( block_id < 0 ) { return ; } - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - const int head_id = i / head_size; - const int head_offset = i % head_size; - const int key_src_id = seq_id * key_stride + i; - const int value_src_id = seq_id * value_stride + i; - const int target_src_id = block_id * hidden_size * block_size - + head_id * block_size * head_size - + block_offset * head_size + head_offset; + for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; - key_cache[target_src_id] = key[key_src_id]; - value_cache[target_src_id] = value[value_src_id]; + copy_vector(key_cache + target_id, key + key_src_id); + copy_vector(value_cache + target_id, value + value_src_id); } } -void decode_kv_cache_memcpy( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size] - torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size] - torch::Tensor& sequence_lengths, // [batch_size] - torch::Tensor& block_tables) // [batch_size, max_seq_len] +template +void apply_decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] { int num_tokens = key.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); + int head_num = key.size(1); + int head_dim = key.size(2); int block_size = key_cache.size(2); - int key_stride = key.stride(0); - int value_stride = value.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); int block_table_stride = block_tables.stride(0); + int vec_size = get_vec_size(key); + + if (head_dim % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. 
+ vec_size = 1; + } + + int thread_nums = head_num * head_dim / vec_size; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - DISPATCH_FLOAT_HALF_AND_BFLOAT( - key.scalar_type(), - "decode_kv_cache_memcpy", - decode_kv_cache_memcpy_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - sequence_lengths.data_ptr(), - block_tables.data_ptr(), - num_heads, - head_size, - block_size, - key_stride, - value_stride, - block_table_stride - );) + dim3 block(std::min(thread_nums, 512)); + + switch (vec_size) { + case 1: + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + key_stride, + value_stride, + block_table_stride + ); + break; + case 2: + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + key_stride, + value_stride, + block_table_stride + ); + break; + case 4: + decode_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + key_stride, + value_stride, + block_table_stride + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } AT_CUDA_CHECK(cudaGetLastError()); } + +void decode_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "decode_kv_cache_memcpy", + apply_decode_kv_cache_memcpy( + key, + value, + key_cache, + value_cache, + sequence_lengths, + block_tables + );) +} diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu new file mode 100644 index 000000000..c1db06d3f --- /dev/null +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -0,0 +1,472 @@ + +#include +#include + +#include "../common/vector_copy_utils.h" +#include "../common/micros.h" + +template +__device__ void apply_emb_rotary_compute( + scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, const int64_t stride, + const int token_id, const int shard_block_size, const int half_head_dim, + const int head_num, const int head_dim) { + scalar_t x[VecSize]; + scalar_t y[VecSize]; + scalar_t out_x[VecSize]; + scalar_t out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * stride + (i / half_head_dim) * head_dim + head_offset; + + copy_vector(x, src + addr_offset); + copy_vector(y, src + addr_offset + half_head_dim); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + 
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - + y[j] * sin_ptr[j * 32 + shard_offset]; + out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + + x[j] * sin_ptr[j * 32 + shard_offset]; + } + + copy_vector(src + addr_offset, out_x); + copy_vector(src + addr_offset + half_head_dim, out_y); + } +} + +template +__device__ void apply_kv_memcopy( + scalar_t* __restrict__ src, scalar_t* __restrict__ cache, + const int64_t stride, const int token_id, const int block_id, + const int hidden_size, const int block_size, const int block_offset, + const int head_dim, const int half_head_dim) { + for (int i = threadIdx.x * VecSize; i < hidden_size / 2; + i += blockDim.x * VecSize) { + const int head_id = i / half_head_dim; + const int head_offset = i % half_head_dim; + const int64_t src_id = token_id * stride + head_id * head_dim + head_offset; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(cache + target_id, src + src_id); + copy_vector(cache + target_id + half_head_dim, + src + src_id + half_head_dim); + } +} + +template +__device__ void cos_sin_memory_access( + const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin, + scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id, + const int shard_block_size, const int cos_stride, const int sin_stride, + const int half_head_dim) { + for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { + // We assume that the value of head_dim is less than 128*128. + const int shard_offset = (i % shard_block_size) / VecSize; + const int shard_head = + (i / shard_block_size) * shard_block_size + i % VecSize * 32; + cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i]; + sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i]; + } +} + +template +__device__ void apply_k_rotary_emb_compute( + scalar_t* __restrict__ key, scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, const int64_t key_stride, + const int64_t value_stride, const int token_id, + const int block_table_stride, const int head_num, const int head_dim, + const int kv_head_num, const int block_size, const int half_head_dim, + const int shard_block_size) { + const int seq_len = sequence_lengths[token_id] - 1; + const int block_offset = seq_len % block_size; + const int block_id = + block_tables[token_id * block_table_stride + seq_len / block_size]; + + if (block_id < 0) { + return; + } + + scalar_t x[VecSize]; + scalar_t y[VecSize]; + scalar_t out_x[VecSize]; + scalar_t out_y[VecSize]; + + for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim; + i += blockDim.x * VecSize) { + const int head_offset = i % half_head_dim; + const int shard_offset = + (head_offset / shard_block_size) * shard_block_size + + (head_offset % shard_block_size) / VecSize; + const int64_t addr_offset = + token_id * key_stride + (i / half_head_dim) * head_dim + head_offset; + const int64_t target_id = block_id * head_num * head_dim * block_size + + (i / half_head_dim) * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(x, key + addr_offset); + copy_vector(y, key + addr_offset + half_head_dim); + +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - + y[j] * sin_ptr[j * 32 + 
shard_offset]; + out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + + x[j] * sin_ptr[j * 32 + shard_offset]; + } + + copy_vector(key_cache + target_id, out_x); + copy_vector(key_cache + target_id + half_head_dim, + out_y); + } + + // apply value memcopy + apply_kv_memcopy( + value, value_cache, value_stride, token_id, block_id, head_num * head_dim, + block_size, block_offset, head_dim, half_head_dim); +} + +template +__global__ void rotary_embedding_and_cache_copy_kernel( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + scalar_t* __restrict__ value, + const scalar_t* __restrict__ cos, + const scalar_t* __restrict__ sin, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ block_tables, + const int64_t query_stride, + const int64_t key_stride, + const int64_t value_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int block_table_stride, + const int head_num, + const int head_dim, + const int kv_head_num, + const int block_size +) { + + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + scalar_t *cos_ptr = (scalar_t*)shard_ptr; + scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key and copy kv + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); +} + +template +__global__ void rotary_embedding_kernel( + scalar_t* __restrict__ query, + scalar_t* __restrict__ key, + const scalar_t* __restrict__ cos, + const scalar_t* __restrict__ sin, + const int64_t query_stride, + const int64_t key_stride, + const int64_t half_shard_element_num, + const int cos_stride, + const int sin_stride, + const int head_num, + const int head_dim, + const int kv_head_num +) { + const int token_id = blockIdx.x; + const int half_head_dim = head_dim / 2; + const int shard_block_size = VecSize * 32; + + extern __shared__ char shard_ptr[]; + + scalar_t *cos_ptr = (scalar_t*)shard_ptr; + scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + + // apply cos_sin memcopy + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + __syncthreads(); + + //compute query + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + + //compute key + apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); +} + +template +void apply_rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, 
head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + int block_size = key_cache.size(2); + + int64_t query_stride = query.stride(0); + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + case 2: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + case 4: + rotary_embedding_and_cache_copy_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + value.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + block_tables.data_ptr(), + query_stride, + key_stride, + value_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + block_table_stride, + head_num, + head_dim, + kv_head_num, + block_size + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + + AT_CUDA_CHECK(cudaGetLastError()); +} + +template +void apply_rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] +){ + int num_tokens = query.size(0); + int head_num = query.size(1); + int head_dim = query.size(2); + int kv_head_num = key.size(1); + + int query_stride = query.stride(0); + int key_stride = key.stride(0); + int cos_stride = cos.stride(0); + int sin_stride = sin.stride(0); + + int vec_size = get_vec_size(query); + + if ((head_dim / 2) % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. 
+ vec_size = 1; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int thread_nums = head_num * head_dim / vec_size / 2; + const int shard_block_size = vec_size * 32 * 2; + + dim3 grid(num_tokens); + dim3 block(std::min(thread_nums, 512)); + int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ; + + switch (vec_size) { + case 1: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 2: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + case 4: + rotary_embedding_kernel<<>>( + query.data_ptr(), + key.data_ptr(), + cos.data_ptr(), + sin.data_ptr(), + query_stride, + key_stride, + shard_element_num / 2, + cos_stride, + sin_stride, + head_num, + head_dim, + kv_head_num + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + AT_CUDA_CHECK(cudaGetLastError()); +} + +void rotary_embedding_and_cache_copy( + at::Tensor& query, // [num_tokens, head_num, head_dim] + at::Tensor& key, // [num_tokens, kv_head_num, head_dim] + at::Tensor& value, // [num_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [num_tokens, head_dim] + at::Tensor& sin, // [num_tokens, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& block_tables) // [batch_size, max_seq_len] +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + query.scalar_type(), + "rotary_embedding_and_cache_copy", + apply_rotary_embedding_and_cache_copy( + query, + key, + value, + cos, + sin, + key_cache, + value_cache, + sequence_lengths, + block_tables + );) +} + +void rotary_embedding( + at::Tensor& query, // [total_tokens, head_num, head_dim] + at::Tensor& key, // [total_tokens, kv_head_num, head_dim] + at::Tensor& cos, // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] +){ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + query.scalar_type(), + "rotary_embedding", + apply_rotary_embedding( + query, + key, + cos, + sin + );) +} diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp index 73ed49e6c..4282f5382 100644 --- a/extensions/csrc/cuda/pybind/inference.cpp +++ b/extensions/csrc/cuda/pybind/inference.cpp @@ -9,6 +9,23 @@ void decode_kv_cache_memcpy( torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& block_tables); // [batch_size, max_seq_len] +void rotary_embedding( + torch::Tensor& query, // [total_tokens, head_num, head_dim] + torch::Tensor& key, // [total_tokens, kv_head_num, head_dim] + torch::Tensor& cos, // [total_tokens, head_dim] + torch::Tensor& sin); // [total_tokens, head_dim] + +void rotary_embedding_and_cache_copy( + torch::Tensor& query, // [num_tokens, head_num, head_dim] + torch::Tensor& key, // [num_tokens, kv_head_num, head_dim] + torch::Tensor& value, // [num_tokens, num_heads, head_dim] + torch::Tensor& cos, // [num_tokens, head_dim] + torch::Tensor& sin, // [num_tokens, head_dim] + torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim] + torch::Tensor& + value_cache, // [num_blocks, num_heads, block_size, 
head_dim] + torch::Tensor& sequence_lengths, // [batch_size] + torch::Tensor& block_tables); // [batch_size, max_seq_len] torch::Tensor silu_and_mul(const torch::Tensor& ins); void rms_layernorm(torch::Tensor& out, // [..., hidden_size] @@ -25,6 +42,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def( + "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, + "performing Rotary Embedding-related calculations and KVCache Memcopy."); + + m.def("rotary_embedding", &rotary_embedding, + "performing Rotary Embedding-related calculations."); + m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply"); m.def("rms_layernorm", &rms_layernorm, diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index f465fe600..ae3754ca7 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension): for fname in [ "cuda/pybind/inference.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", + "cuda/fused_rotary_emb_and_cache_kernel.cu", "cuda/activation_kernel.cu", "cuda/rms_layernorm_kernel.cu", ] diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index edd92bb96..25b2c2f43 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -22,15 +22,11 @@ def setup_seed(seed): def check_inference_engine(use_engine=False, prompt_template=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - model = ( - LlamaForCausalLM( - LlamaConfig( - vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 - ) + model = LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 ) - .cuda() - .half() - ) + ).cuda() model = model.eval() inputs = [ @@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None): top_k = 50 if use_engine: - inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template) + inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32") inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py new file mode 100644 index 000000000..b9c0a3269 --- /dev/null +++ b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py @@ -0,0 +1,91 @@ +import pytest +import torch +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb + +from colossalai.kernel.kernel_loader import InferenceOpsLoader + +inference_ops = InferenceOpsLoader().load() + +from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2 +from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb + + +@pytest.mark.parametrize("BATCH_SIZE", [4]) +@pytest.mark.parametrize("SEQ_LEN", [64]) +@pytest.mark.parametrize("H", [32]) +@pytest.mark.parametrize("D", [64]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def 
test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+    torch.manual_seed(10)
+    TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
+    # our crafted op should be equivalent to the Transformers implementation
+    x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
+    x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
+    emb = LlamaRotaryEmbedding(D)
+    cos, sin = emb(x0, TOTAL_TOKENS)
+    cos_2 = cos[:, : D // 2]
+    sin_2 = sin[:, : D // 2]
+    position_ids = torch.arange(TOTAL_TOKENS)
+    embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
+    embd_simulated_x = torch_rotary_emb(x0, cos_2, sin_2)
+    assert torch.allclose(embd_x0, embd_simulated_x)
+
+    # create data
+    block_size = 32
+    max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
+    q_shape = (TOTAL_TOKENS, H, D)
+    q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
+    k_shape = (TOTAL_TOKENS, H, D)
+    k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
+    cos_shape = (TOTAL_TOKENS, D // 2)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    v = torch.randn_like(k)
+    v_cache = torch.zeros_like(k_cache)
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
+    )
+    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    new_v = torch.randn_like(new_k)
+
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")
+    q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
+    k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
+
+    new_q_copy = new_q.clone()
+    new_k_copy = new_k.clone()
+
+    inference_ops.rotary_embedding_and_cache_copy(
+        new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
+    )
+
+    inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
+
+    past_kv_seq_len = kv_seq_lengths - 1
+    target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
+    offsets_in_block = past_kv_seq_len % block_size
+    k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
+    k_source = new_k_copy.squeeze()
+    v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
+    v_source = new_v.squeeze()
+
+    assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
+    assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
+
+    assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
+    assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
+
+    assert k_target.shape == k_source.shape
+    assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
+
+    assert v_target.shape == v_source.shape
+    assert torch.equal(v_target, v_source)
+
+
+if __name__ == "__main__":
+    test_rotary_emb(16, 512, 4, 128, torch.float16)
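
Note: the reference helper torch_rotary_emb imported above lives in tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py and is not part of this patch. For reviewers of the CUDA kernels, the rotation the tests appear to expect is the GPT-NeoX-style "split halves" form; the snippet below is a minimal sketch under that assumption (torch_rotary_emb_sketch is a hypothetical name), not the exact helper shipped with the tests.

    import torch

    def torch_rotary_emb_sketch(x, cos, sin):
        # x:   [total_tokens, head_num, head_dim]
        # cos: [total_tokens, head_dim // 2]
        # sin: [total_tokens, head_dim // 2]
        half = x.shape[-1] // 2
        x0, x1 = x[..., :half], x[..., half:]
        cos = cos.unsqueeze(1)  # broadcast over the head dimension
        sin = sin.unsqueeze(1)
        out0 = x0 * cos - x1 * sin
        out1 = x0 * sin + x1 * cos
        return torch.cat((out0, out1), dim=-1)

Under this assumption, rotary_embedding applies the rotation to new_q/new_k in place, while rotary_embedding_and_cache_copy additionally scatters the rotated new_k and the untouched new_v into the paged key/value caches at the slots selected by block_tables and the per-sequence lengths, which is what the assertions in the test verify.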