fix rmsnorm template function invocation problem(template function partial specialization is not allowed in Cpp) and luckily pass e2e precision test (#5454)

This commit is contained in:
Steve Luo
2024-03-13 16:00:55 +08:00
committed by GitHub
parent 6fd355a5a6
commit ed431de4e4
2 changed files with 79 additions and 35 deletions

View File

@@ -12,6 +12,34 @@
#include "../common/micros.h"
#include "../common/cuda_type_utils.h"
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
if (DATA_SIZE == 2) { \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
} else { \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
general_##__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
} \
// optimized for half and bf16
template<typename scalar_t, int unroll_factor>
__global__ void rms_layernorm_kernel(
@@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel(
}
}
template<int unroll_factor>
__global__ void rms_layernorm_kernel(
float* __restrict__ out, // [..., hidden_size]
const float* __restrict__ input, // [..., hidden_size]
const float* __restrict__ weight, // [hidden_size]
template<typename scalar_t, int unroll_factor>
__global__ void general_rms_layernorm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input[id];
x_local[cnt] = (float) input[id];
variance += x_local[cnt] * x_local[cnt];
}
variance = blockReduceSum<float>(variance);
@@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out[id] = ((x_local[cnt] * s_variance)) * weight[idx];
out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}
@@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel(
}
}
template<int unroll_factor>
__global__ void fused_add_rms_layernorm_kernel(
float* __restrict__ input, // [..., hidden_size]
float* __restrict__ residual, // [..., hidden_size]
const float* __restrict__ weight, // [hidden_size]
template<typename scalar_t, int unroll_factor>
__global__ void general_fused_add_rms_layernorm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
@@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input[id];
x_local[cnt] += residual[id];
x_local[cnt] = (float) input[id];
x_local[cnt] += (float) residual[id];
variance += x_local[cnt] * x_local[cnt];
residual[id] = x_local[cnt];
residual[id] = (scalar_t) x_local[cnt];
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
@@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel(
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input[id] = ((x_local[cnt] * s_variance)) * weight[idx];
input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
}
}
@@ -190,7 +218,8 @@ void rms_layernorm(
if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
@@ -201,7 +230,8 @@ void rms_layernorm(
num_tokens,
hidden_size);)
} else {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
@@ -216,11 +246,12 @@ void rms_layernorm(
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
@@ -232,7 +263,8 @@ void rms_layernorm(
hidden_size);)
break;
case 2:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
@@ -244,7 +276,8 @@ void rms_layernorm(
hidden_size);)
break;
case 4:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
@@ -256,7 +289,8 @@ void rms_layernorm(
hidden_size);)
break;
case 8:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"rms_layernorm_kernel",
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
@@ -288,7 +322,8 @@ void fused_add_rms_layernorm(
if (num_tokens >= 512) {
if (input.scalar_type() == at::ScalarType::Float) {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
@@ -299,7 +334,8 @@ void fused_add_rms_layernorm(
num_tokens,
hidden_size);)
} else {
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
@@ -314,11 +350,12 @@ void fused_add_rms_layernorm(
int unroll_factor = (hidden_size + block.x - 1) / block.x;
if (input.scalar_type() != at::ScalarType::Float) {
block.x = std::min(hidden_size / 2, 1024);
int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
}
switch (unroll_factor) {
case 1:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
@@ -330,7 +367,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 2:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
@@ -342,7 +380,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 4:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
@@ -354,7 +393,8 @@ void fused_add_rms_layernorm(
hidden_size);)
break;
case 8:
DISPATCH_FLOAT_HALF_AND_BFLOAT(
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
input.element_size(),
input.scalar_type(),
"fused_add_rms_layernorm_kernel",
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(