diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 0e3e4e900..8b250cb10 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -12,6 +12,34 @@ #include "../common/micros.h" #include "../common/cuda_type_utils.h" +#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \ + if (DATA_SIZE == 2) { \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } else { \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + general_##__VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + } \ + // optimized for half and bf16 template __global__ void rms_layernorm_kernel( @@ -63,11 +91,11 @@ __global__ void rms_layernorm_kernel( } } -template -__global__ void rms_layernorm_kernel( - float* __restrict__ out, // [..., hidden_size] - const float* __restrict__ input, // [..., hidden_size] - const float* __restrict__ weight, // [hidden_size] +template +__global__ void general_rms_layernorm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -80,7 +108,7 @@ __global__ void rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - x_local[cnt] = input[id]; + x_local[cnt] = (float) input[id]; variance += x_local[cnt] * x_local[cnt]; } variance = blockReduceSum(variance); @@ -92,7 +120,7 @@ __global__ void rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - out[id] = ((x_local[cnt] * s_variance)) * weight[idx]; + out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; } } @@ -140,11 +168,11 @@ __global__ void fused_add_rms_layernorm_kernel( } } -template -__global__ void fused_add_rms_layernorm_kernel( - float* __restrict__ input, // [..., hidden_size] - float* __restrict__ residual, // [..., hidden_size] - const float* __restrict__ weight, // [hidden_size] +template +__global__ void general_fused_add_rms_layernorm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -157,10 +185,10 @@ __global__ void fused_add_rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - x_local[cnt] = input[id]; - x_local[cnt] += residual[id]; + x_local[cnt] = (float) input[id]; + x_local[cnt] += (float) residual[id]; variance += x_local[cnt] * x_local[cnt]; - residual[id] = x_local[cnt]; + residual[id] = (scalar_t) x_local[cnt]; } variance = blockReduceSum(variance); if (threadIdx.x == 0) { @@ -171,7 +199,7 @@ __global__ void fused_add_rms_layernorm_kernel( #pragma unroll unroll_factor for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) { int id = row_offset + idx; - input[id] = ((x_local[cnt] * s_variance)) * weight[idx]; + input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx]; } } @@ -190,7 +218,8 @@ void rms_layernorm( if (num_tokens >= 512) { if (input.scalar_type() == at::ScalarType::Float) { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -201,7 +230,8 @@ void rms_layernorm( num_tokens, hidden_size);) } else { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -216,11 +246,12 @@ void rms_layernorm( int unroll_factor = (hidden_size + block.x - 1) / block.x; if (input.scalar_type() != at::ScalarType::Float) { block.x = std::min(hidden_size / 2, 1024); - int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; } switch (unroll_factor) { case 1: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -232,7 +263,8 @@ void rms_layernorm( hidden_size);) break; case 2: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -244,7 +276,8 @@ void rms_layernorm( hidden_size);) break; case 4: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -256,7 +289,8 @@ void rms_layernorm( hidden_size);) break; case 8: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "rms_layernorm_kernel", rms_layernorm_kernel<<>>( @@ -288,7 +322,8 @@ void fused_add_rms_layernorm( if (num_tokens >= 512) { if (input.scalar_type() == at::ScalarType::Float) { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -299,7 +334,8 @@ void fused_add_rms_layernorm( num_tokens, hidden_size);) } else { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -314,11 +350,12 @@ void fused_add_rms_layernorm( int unroll_factor = (hidden_size + block.x - 1) / block.x; if (input.scalar_type() != at::ScalarType::Float) { block.x = std::min(hidden_size / 2, 1024); - int unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; + unroll_factor = (hidden_size / 2 + block.x - 1) / block.x; } switch (unroll_factor) { case 1: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -330,7 +367,8 @@ void fused_add_rms_layernorm( hidden_size);) break; case 2: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -342,7 +380,8 @@ void fused_add_rms_layernorm( hidden_size);) break; case 4: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( @@ -354,7 +393,8 @@ void fused_add_rms_layernorm( hidden_size);) break; case 8: - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT( + input.element_size(), input.scalar_type(), "fused_add_rms_layernorm_kernel", fused_add_rms_layernorm_kernel<<>>( diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index 25b2c2f43..edd92bb96 100644 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -22,11 +22,15 @@ def setup_seed(seed): def check_inference_engine(use_engine=False, prompt_template=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - model = LlamaForCausalLM( - LlamaConfig( - vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + model = ( + LlamaForCausalLM( + LlamaConfig( + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16 + ) ) - ).cuda() + .cuda() + .half() + ) model = model.eval() inputs = [ @@ -40,7 +44,7 @@ def check_inference_engine(use_engine=False, prompt_template=None): top_k = 50 if use_engine: - inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32") + inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs)