From 8a5d526e95d1fae6ea85b76cffcbb318d29a5100 Mon Sep 17 00:00:00 2001
From: ExtremeViscent
Date: Sat, 2 Apr 2022 02:29:45 +0100
Subject: [PATCH] [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu and cross_entropy.cu code style (#634)

---
 .../cuda_native/csrc/kernels/cross_entropy.cu |  2 +-
 .../csrc/kernels/dropout_kernels.cu           | 42 ++++++++++++-------
 2 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
index 58d26235a..b12437870 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu
@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
   const int left_idx = block_start + threadIdx.x;
   const int right_idx = (blockIdx.x + 1) * vocab_size;
   float max_input[1] = {REDUCE_FLOAT_INF_NEG};
-  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
+  float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
 
   int target_tid = targets[blockIdx.x];
   if (target_tid == padding_idx) {
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
index 7d314c11e..9ccf09d76 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   uint8_t m[4];
 
@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   float4 *out4 = reinterpret_cast<float4 *>(out);
   const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();
 
-  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < 32; i <<= 1)
+    sum += g.shfl_down(sum, i);
 
-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();
 
   if (threadIdx.x < 8) {
@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();
 
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);
 
-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();
 
   if (threadIdx.x < 8) {
@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
   float sum = tile[threadIdx.y][threadIdx.x];
   __syncthreads();
 
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);
 
-  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
+  if (threadIdx.x == 0)
+    tile[0][threadIdx.y] = sum;
   __syncthreads();
 
   if (threadIdx.y == 0) {
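For context, the `if (i * 4 >= total_count)` guards reformatted above exist because each thread of these forward kernels handles four elements at once: it draws four uniforms from a per-thread Philox stream, derives four mask bytes, and issues one vectorized 16-byte load/store. Below is a minimal standalone sketch of that pattern, not the library's kernel; the name `dropout_fw_sketch` is illustrative, and it assumes `total_count` is a multiple of 4 with suitably aligned pointers.

#include <cstdint>
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void dropout_fw_sketch(float *out, uint8_t *mask, const float *in,
                                  const int total_count, const float ratio,
                                  const int seed) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 4 >= total_count)
    return;

  // One Philox substream per thread: the same (seed, subsequence) pair
  // always yields the same four uniforms, so the mask is reproducible.
  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  float4 rand = curand_uniform4(&state);  // four uniforms in (0, 1]

  // Keep each element with probability 1 - ratio.
  uchar4 m;
  m.x = rand.x > ratio;
  m.y = rand.y > ratio;
  m.z = rand.z > ratio;
  m.w = rand.w > ratio;

  // Vectorized 16-byte load/store; assumes 16-byte-aligned in/out and
  // total_count % 4 == 0 (the launcher is assumed to arrange this).
  float4 v = reinterpret_cast<const float4 *>(in)[i];
  float4 r;
  r.x = v.x * scale * m.x;
  r.y = v.y * scale * m.y;
  r.z = v.z * scale * m.z;
  r.w = v.w * scale * m.w;
  reinterpret_cast<float4 *>(out)[i] = r;
  reinterpret_cast<uchar4 *>(mask)[i] = m;  // four mask bytes in one store
}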
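Similarly, the two-line `for` loops in `ls_dropout_bias_bwd_kernel` and `ls_dropout_act_bias_bwd_kernel` are a warp-level shuffle-down tree reduction: rounds at offsets 1, 2, 4, 8, 16 fold 32 partial sums into lane 0 in five steps. A minimal sketch of that pattern, assuming one 32-thread block per output element (`warp_sum_sketch` and the launch shape are illustrative, not from the patch):

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

__global__ void warp_sum_sketch(const float *in, float *out) {
  // One 32-lane tile per block; launch with blockDim.x == 32.
  cg::thread_block_tile<32> g =
      cg::tiled_partition<32>(cg::this_thread_block());

  float sum = in[blockIdx.x * 32 + g.thread_rank()];
  // After rounds at offsets 1, 2, 4, 8, 16, lane 0 holds the tile total;
  // the remaining lanes hold partial sums that are simply discarded.
  for (int i = 1; i < 32; i <<= 1)
    sum += g.shfl_down(sum, i);

  if (g.thread_rank() == 0)
    out[blockIdx.x] = sum;
}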