mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
fix rmsnorm template function invocation problem(template function partial specialization is not allowed in Cpp) and luckily pass e2e precision test (#5454)
This commit is contained in:
@@ -12,6 +12,34 @@
|
||||
#include "../common/micros.h"
|
||||
#include "../common/cuda_type_utils.h"
|
||||
|
||||
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
|
||||
if (DATA_SIZE == 2) { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
} \
|
||||
} else { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t = float; \
|
||||
general_##__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
} \
|
||||
} \
|
||||
|
||||
// optimized for half and bf16
|
||||
template<typename scalar_t, int unroll_factor>
|
||||
__global__ void rms_layernorm_kernel(
|
||||
@@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template<int unroll_factor>
|
||||
__global__ void rms_layernorm_kernel(
|
||||
float* __restrict__ out, // [..., hidden_size]
|
||||
const float* __restrict__ input, // [..., hidden_size]
|
||||
const float* __restrict__ weight, // [hidden_size]
|
||||
template<typename scalar_t, int unroll_factor>
|
||||
__global__ void general_rms_layernorm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
@@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel(
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
x_local[cnt] = input[id];
|
||||
x_local[cnt] = (float) input[id];
|
||||
variance += x_local[cnt] * x_local[cnt];
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
@@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel(
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
|
||||
out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template<int unroll_factor>
|
||||
__global__ void fused_add_rms_layernorm_kernel(
|
||||
float* __restrict__ input, // [..., hidden_size]
|
||||
float* __restrict__ residual, // [..., hidden_size]
|
||||
const float* __restrict__ weight, // [hidden_size]
|
||||
template<typename scalar_t, int unroll_factor>
|
||||
__global__ void general_fused_add_rms_layernorm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
@@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel(
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
x_local[cnt] = input[id];
|
||||
x_local[cnt] += residual[id];
|
||||
x_local[cnt] = (float) input[id];
|
||||
x_local[cnt] += (float) residual[id];
|
||||
variance += x_local[cnt] * x_local[cnt];
|
||||
residual[id] = x_local[cnt];
|
||||
residual[id] = (scalar_t) x_local[cnt];
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel(
|
||||
#pragma unroll unroll_factor
|
||||
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||
int id = row_offset + idx;
|
||||
input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
|
||||
input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +218,8 @@ void rms_layernorm(
|
||||
|
||||
if (num_tokens >= 512) {
|
||||
if (input.scalar_type() == at::ScalarType::Float) {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
@@ -201,7 +230,8 @@ void rms_layernorm(
|
||||
num_tokens,
|
||||
hidden_size);)
|
||||
} else {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
@@ -216,11 +246,12 @@ void rms_layernorm(
|
||||
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||
if (input.scalar_type() != at::ScalarType::Float) {
|
||||
block.x = std::min(hidden_size / 2, 1024);
|
||||
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
}
|
||||
switch (unroll_factor) {
|
||||
case 1:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||
@@ -232,7 +263,8 @@ void rms_layernorm(
|
||||
hidden_size);)
|
||||
break;
|
||||
case 2:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
@@ -244,7 +276,8 @@ void rms_layernorm(
|
||||
hidden_size);)
|
||||
break;
|
||||
case 4:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
@@ -256,7 +289,8 @@ void rms_layernorm(
|
||||
hidden_size);)
|
||||
break;
|
||||
case 8:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"rms_layernorm_kernel",
|
||||
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||
@@ -288,7 +322,8 @@ void fused_add_rms_layernorm(
|
||||
|
||||
if (num_tokens >= 512) {
|
||||
if (input.scalar_type() == at::ScalarType::Float) {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
@@ -299,7 +334,8 @@ void fused_add_rms_layernorm(
|
||||
num_tokens,
|
||||
hidden_size);)
|
||||
} else {
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||
@@ -314,11 +350,12 @@ void fused_add_rms_layernorm(
|
||||
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||
if (input.scalar_type() != at::ScalarType::Float) {
|
||||
block.x = std::min(hidden_size / 2, 1024);
|
||||
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||
}
|
||||
switch (unroll_factor) {
|
||||
case 1:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||
@@ -330,7 +367,8 @@ void fused_add_rms_layernorm(
|
||||
hidden_size);)
|
||||
break;
|
||||
case 2:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
@@ -342,7 +380,8 @@ void fused_add_rms_layernorm(
|
||||
hidden_size);)
|
||||
break;
|
||||
case 4:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
@@ -354,7 +393,8 @@ void fused_add_rms_layernorm(
|
||||
hidden_size);)
|
||||
break;
|
||||
case 8:
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||
input.element_size(),
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_layernorm_kernel",
|
||||
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||
|
Reference in New Issue
Block a user