Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)
extensions/csrc/cuda/colossal_inference_C_frontend.cpp
@@ -11,8 +11,25 @@ void decode_kv_cache_memcpy(

torch::Tensor silu_and_mul(const torch::Tensor& ins);

void rms_layernorm(torch::Tensor& out,     // [..., hidden_size]
                   torch::Tensor& input,   // [..., hidden_size]
                   torch::Tensor& weight,  // [hidden_size]
                   float epsilon);

void fused_add_rms_layernorm(torch::Tensor& input,     // [..., hidden_size]
                             torch::Tensor& residual,  // [..., hidden_size]
                             torch::Tensor& weight,    // [hidden_size]
                             float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
        "Copy the GPU memory of kvcache during the decode stage.");

  m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");

  m.def("rms_layernorm", &rms_layernorm,
        "Apply Root Mean Square (RMS) Normalization to the input tensor.");

  m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
        "In-place fused Add and RMS Normalization.");
}
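For orientation, once the extension is compiled these bindings land on an ordinary Python module object. A minimal sketch using PyTorch's ad hoc JIT loader; the module name and relative paths here are illustrative assumptions, since ColossalAI ships its own extension loader (see the InferenceOpsCudaExtension hunk below):

# Minimal sketch: build the sources from this commit with torch's JIT loader
# and call the new binding. Module name and paths are assumptions.
import torch
from torch.utils.cpp_extension import load

inference_ops = load(
    name="inference_ops_cuda",  # hypothetical module name
    sources=[
        "extensions/csrc/cuda/colossal_inference_C_frontend.cpp",
        "extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu",
        "extensions/csrc/cuda/activation_kernel.cu",
        "extensions/csrc/cuda/rms_layernorm_kernel.cu",
    ],
    extra_include_paths=["extensions/csrc/cuda/include"],
)

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

inference_ops.rms_layernorm(out, x, w, 1e-5)  # writes the result into `out`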

extensions/csrc/cuda/rms_layernorm_kernel.cu (new file, 126 lines)
@@ -0,0 +1,126 @@
/* This code is from VLLM:
 * https://github.com/vllm-project/vllm/
 * with minor changes. */

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>

#include "block_reduce.h"
#include "type_shim.h"

template<typename scalar_t>
__global__ void rms_layernorm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon,
    const int num_tokens,
    const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
  /*
   * Since the hidden dimensions of open-sourced LLMs mainly range from
   * 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we set the supported hidden
   * dimension limit to 8192 and each thread's capacity for caching
   * input elements to 8 (8192 = 8 * 1024). This will cause problems
   * for extremely large models, such as Megatron-Turing NLG 530B,
   * whose hidden dimension is up to 20480.
   */
  float x_local[8];

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
    variance += x_local[cnt] * x_local[cnt];
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
  }
}

template<typename scalar_t>
__global__ void fused_add_rms_layernorm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon,
    const int num_tokens,
    const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;
  float x_local[8];

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    x_local[cnt] = (float) input[blockIdx.x * hidden_size + idx];
    x_local[cnt] += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x_local[cnt] * x_local[cnt];
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x_local[cnt];
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
  }
}

void rms_layernorm(
    torch::Tensor& out,     // [..., hidden_size]
    torch::Tensor& input,   // [..., hidden_size]
    torch::Tensor& weight,  // [hidden_size]
    float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
    input.scalar_type(),
    "rms_layernorm_kernel",
    rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
      out.data_ptr<scalar_t>(),
      input.data_ptr<scalar_t>(),
      weight.data_ptr<scalar_t>(),
      epsilon,
      num_tokens,
      hidden_size);)
}

void fused_add_rms_layernorm(
    torch::Tensor& input,     // [..., hidden_size]
    torch::Tensor& residual,  // [..., hidden_size]
    torch::Tensor& weight,    // [hidden_size]
    float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
    input.scalar_type(),
    "fused_add_rms_layernorm_kernel",
    fused_add_rms_layernorm_kernel<scalar_t><<<grid, block, 0, stream>>>(
      input.data_ptr<scalar_t>(),
      residual.data_ptr<scalar_t>(),
      weight.data_ptr<scalar_t>(),
      epsilon,
      num_tokens,
      hidden_size);)
}
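The commit title mentions a unittest; the natural oracle is an eager PyTorch RMSNorm that mirrors the kernels' numerics (variance accumulated in float32, result cast back to the input dtype before the weight multiply). A minimal sketch under that assumption; the function names are illustrative, not the PR's actual test code:

import torch

def ref_rms_layernorm(x: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Mirror the kernel: accumulate the variance in float32, normalize,
    # cast back to the input dtype, then apply the elementwise weight.
    x_f32 = x.float()
    variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
    return (x_f32 * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight

def ref_fused_add_rms_layernorm(x, residual, weight, epsilon):
    # The fused kernel adds the residual in float32, writes the rounded sum
    # back into `residual`, and overwrites `x` with the normalized result.
    s = x.float() + residual.float()
    variance = s.pow(2).mean(dim=-1, keepdim=True)
    normed = (s * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight
    return normed, s.to(x.dtype)  # (normalized output, updated residual)

A test would compare these against the kernel outputs with dtype-appropriate tolerances (e.g. rtol around 1e-3 for float16), and should keep hidden_size at or below 8192: the per-thread cache x_local[8] with at most 1024 threads per block caps the supported hidden dimension at 8 * 1024.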
@@ -13,12 +13,13 @@ class InferenceOpsCudaExtension(_CudaExtension):
             "cuda/colossal_inference_C_frontend.cpp",
             "cuda/decode_kv_cache_memcpy_kernel.cu",
             "cuda/activation_kernel.cu",
+            "cuda/rms_layernorm_kernel.cu",
         ]
     ]
     return ret

     def include_dirs(self):
-        ret = [self.get_cuda_home_include()]
+        ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
         return ret

     def cxx_flags(self):
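The commit title also mentions a benchmark script. A hedged sketch of CUDA-event timing for the new op, reusing the hypothetical `inference_ops` module and `ref_rms_layernorm` baseline sketched above; the PR's actual script may structure this differently:

import torch

def benchmark(fn, warmup: int = 10, iters: int = 100) -> float:
    # Time a CUDA op with events; returns mean milliseconds per call.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(512, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)

t_kernel = benchmark(lambda: inference_ops.rms_layernorm(out, x, w, 1e-5))
t_eager = benchmark(lambda: ref_rms_layernorm(x, w, 1e-5))
print(f"custom kernel: {t_kernel:.3f} ms | eager PyTorch: {t_eager:.3f} ms")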