From f28c0213769dbb2037cd08123ecbf1c8a3f5114b Mon Sep 17 00:00:00 2001
From: binmakeswell
Date: Mon, 16 May 2022 14:09:01 +0800
Subject: [PATCH] [NFC] polish
 colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu code style
 (#978)

---
 .../csrc/multi_tensor_sgd_kernel.cu           | 378 +++++++----------
 1 file changed, 153 insertions(+), 225 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
index bc30e2722..a077bc738 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu
@@ -1,14 +1,15 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
+// modified from
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include "multi_tensor_apply.cuh"
-#include "compat.h"
-
 #include <assert.h>
 #include <cstring>
 
+#include "compat.h"
+#include "multi_tensor_apply.cuh"
+
 #define BLOCK_SIZE 512
 #define ILP 4
 
@@ -28,69 +29,53 @@
  * wd_after_momentum : apply weight decay _after_ momentum instead of before
  **/
 template <int N, typename T_grad, typename T_weight>
-struct SGDFunctor
-{
-    __device__ __forceinline__ void operator()(
-        int chunk_size,
-        volatile int *noop_gmem,
-        TensorListMetadata<N> &tl,
-        float wd,
-        float momentum,
-        float dampening,
-        float lr,
-        bool nesterov,
-        bool first_run,
-        bool wd_after_momentum,
-        float scale)
-    {
-        // Early exit if we don't need to do anything
-        if (*noop_gmem)
-            return;
+struct SGDFunctor {
+  __device__ __forceinline__ void operator()(
+      int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
+      float wd, float momentum, float dampening, float lr, bool nesterov,
+      bool first_run, bool wd_after_momentum, float scale) {
+    // Early exit if we don't need to do anything
+    if (*noop_gmem) return;
 
-        int tensor_loc = tl.block_to_tensor[blockIdx.x];
-        int chunk_idx = tl.block_to_chunk[blockIdx.x];
-        int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
 
-        T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
-        grad_in += chunk_idx * chunk_size;
+    T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
+    grad_in += chunk_idx * chunk_size;
 
-        T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
-        weight_in += chunk_idx * chunk_size;
+    T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
+    weight_in += chunk_idx * chunk_size;
 
-        T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
-        mom_in += chunk_idx * chunk_size;
+    T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
+    mom_in += chunk_idx * chunk_size;
 
-        at::Half *model_weights_out = nullptr;
-        if (N == 4)
-        {
-            model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
-            model_weights_out += chunk_idx * chunk_size;
-        }
+    at::Half *model_weights_out = nullptr;
+    if (N == 4) {
+      model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
+      model_weights_out += chunk_idx * chunk_size;
+    }
 
-        n -= chunk_idx * chunk_size;
+    n -= chunk_idx * chunk_size;
 
-        // Non-divergent exit condition for the __syncthreads
-        float incoming_grads[ILP];
-        float incoming_weights[ILP];
-        float incoming_moms[ILP];
-        for (int i_start = 0;
-             i_start < n && i_start < chunk_size;
-             i_start += blockDim.x * ILP)
-        {
+    // Non-divergent exit condition for the __syncthreads
+    float incoming_grads[ILP];
+    float incoming_weights[ILP];
+    float incoming_moms[ILP];
+    for (int i_start = 0; i_start < n && i_start < chunk_size;
+         i_start += blockDim.x * ILP) {
 #pragma unroll
-            for (int ii = 0; ii < ILP; ii++)
-            {
-                incoming_grads[ii] = 0;
-                incoming_weights[ii] = 0;
-                incoming_moms[ii] = 0;
-                int i = i_start + threadIdx.x + ii * blockDim.x;
-                if (i < n && i < chunk_size)
-                {
-                    incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
-                    incoming_weights[ii] = static_cast<float>(weight_in[i]);
-                    incoming_moms[ii] = static_cast<float>(mom_in[i]);
-                }
-            }
+      for (int ii = 0; ii < ILP; ii++) {
+        incoming_grads[ii] = 0;
+        incoming_weights[ii] = 0;
+        incoming_moms[ii] = 0;
+        int i = i_start + threadIdx.x + ii * blockDim.x;
+        if (i < n && i < chunk_size) {
+          incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
+          incoming_weights[ii] = static_cast<float>(weight_in[i]);
+          incoming_moms[ii] = static_cast<float>(mom_in[i]);
+        }
+      }
 
       // note for clarification to future michael:
       // From a pure memory dependency perspective, there's likely no point unrolling
@@ -98,185 +83,128 @@
       // Put another way, the STGs are dependent on the LDGs, but not on each other.
       // There is still compute ILP benefit from unrolling the loop though.
 #pragma unroll
-            for (int ii = 0; ii < ILP; ii++)
-            {
-                int i = i_start + threadIdx.x + ii * blockDim.x;
-                if (i < n && i < chunk_size)
-                {
-                    // apply weight decay before momentum if necessary
-                    if (wd != 0.f && !wd_after_momentum)
-                        incoming_grads[ii] += wd * incoming_weights[ii];
-
-                    if (momentum != 0.f)
-                    {
-                        if (!first_run)
-                            incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
-                        else // initialize momentums to current incoming grads
-                            incoming_moms[ii] = incoming_grads[ii];
-
-                        if (nesterov)
-                            incoming_grads[ii] += momentum * incoming_moms[ii];
-                        else
-                            incoming_grads[ii] = incoming_moms[ii];
-                    }
-
-                    // Apply WD after momentum if desired
-                    if (wd != 0.f && wd_after_momentum)
-                        incoming_grads[ii] += wd * incoming_weights[ii];
-
-                    // adjust the weight and write out
-                    weight_in[i] += (-lr * incoming_grads[ii]);
-
-                    // if necessary, write out an fp16 copy of the weights
-                    if (N == 4)
-                        model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
-
-                    // also write out the new momentum
-                    if (momentum != 0.f)
-                        mom_in[i] = incoming_moms[ii];
-                }
-            }
-        }
-    }
+      for (int ii = 0; ii < ILP; ii++) {
+        int i = i_start + threadIdx.x + ii * blockDim.x;
+        if (i < n && i < chunk_size) {
+          // apply weight decay before momentum if necessary
+          if (wd != 0.f && !wd_after_momentum)
+            incoming_grads[ii] += wd * incoming_weights[ii];
+
+          if (momentum != 0.f) {
+            if (!first_run)
+              incoming_moms[ii] = incoming_moms[ii] * momentum +
+                                  (1.f - dampening) * incoming_grads[ii];
+            else  // initialize momentums to current incoming grads
+              incoming_moms[ii] = incoming_grads[ii];
+
+            if (nesterov)
+              incoming_grads[ii] += momentum * incoming_moms[ii];
+            else
+              incoming_grads[ii] = incoming_moms[ii];
+          }
+
+          // Apply WD after momentum if desired
+          if (wd != 0.f && wd_after_momentum)
+            incoming_grads[ii] += wd * incoming_weights[ii];
+
+          // adjust the weight and write out
+          weight_in[i] += (-lr * incoming_grads[ii]);
+
+          // if necessary, write out an fp16 copy of the weights
+          if (N == 4)
+            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
+
+          // also write out the new momentum
+          if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
+        }
+      }
+    }
+  }
 };
 
-void multi_tensor_sgd_cuda(
-    int chunk_size,
-    at::Tensor noop_flag,
-    std::vector<std::vector<at::Tensor>> tensor_lists,
-    float wd,
-    float momentum,
-    float dampening,
-    float lr,
-    bool nesterov,
-    bool first_run,
-    bool wd_after_momentum,
-    float scale)
-{
-    auto num_tensors = tensor_lists.size();
-    auto grad_type = tensor_lists[0][0].scalar_type();
-    auto weight_type = tensor_lists[1][0].scalar_type();
+void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
+                           std::vector<std::vector<at::Tensor>> tensor_lists,
+                           float wd, float momentum, float dampening, float lr,
+                           bool nesterov, bool first_run,
+                           bool wd_after_momentum, float scale) {
+  auto num_tensors = tensor_lists.size();
+  auto grad_type = tensor_lists[0][0].scalar_type();
+  auto weight_type = tensor_lists[1][0].scalar_type();
 
-    if (num_tensors == 4)
-        for (int i = 0; i < tensor_lists[3].size(); i++)
-            TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
-                        "Additional output tensors should always be fp16.");
+  if (num_tensors == 4)
+    for (int i = 0; i < tensor_lists[3].size(); i++)
+      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
+                  "Additional output tensors should always be fp16.");
 
-    TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
+  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
+              "expected noop flag to be on the same device as tensors");
 
-    // We have 3 possibilities to handle here, in terms of
-    // grad_type, param_type, momentum_type, requires_fp16_copy
-    // 1. fp16, fp16, fp16, No
-    // 2. fp32, fp32, fp32, No
-    // 3. fp16, fp32, fp32, Yes
-    // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
-    // It's easier to hardcode these possibilities than to use
-    // switches etc. to handle the cross-product of cases where
-    // we don't want the majority of them.
+  // We have 3 possibilities to handle here, in terms of
+  // grad_type, param_type, momentum_type, requires_fp16_copy
+  // 1. fp16, fp16, fp16, No
+  // 2. fp32, fp32, fp32, No
+  // 3. fp16, fp32, fp32, Yes
+  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+  // It's easier to hardcode these possibilities than to use
+  // switches etc. to handle the cross-product of cases where
+  // we don't want the majority of them.
 
-    // Case 1. fp16, fp16, fp16, No
-    if (grad_type == at::ScalarType::Half &&
-        weight_type == at::ScalarType::Half &&
-        num_tensors == 3)
-    {
-        multi_tensor_apply<3>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<3, at::Half, at::Half>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run,
-            wd_after_momentum,
-            scale);
-    }
-    // Case 2. fp16, fp32, fp32, No
-    // else if (grad_type == at::ScalarType::Half &&
-    //          weight_type == at::ScalarType::Float &&
-    //          num_tensors == 3) {
-    //   multi_tensor_apply<3>(
-    //       BLOCK_SIZE,
-    //       chunk_size,
-    //       noop_flag,
-    //       tensor_lists,
-    //       SGDFunctor<3, at::Half, float>(),
-    //       wd,
-    //       momentum,
-    //       dampening,
-    //       lr,
-    //       nesterov,
-    //       first_run,
-    //       wd_after_momentum);
-    // }
-    // Case 2. fp32, fp32, fp32, No
-    else if (grad_type == at::ScalarType::Float &&
-             weight_type == at::ScalarType::Float &&
-             num_tensors == 3)
-    {
-        multi_tensor_apply<3>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<3, float, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run,
-            wd_after_momentum,
-            scale);
-    }
-    // Case 3. fp16, fp32, fp32, Yes
-    else if (grad_type == at::ScalarType::Half &&
-             weight_type == at::ScalarType::Float &&
-             num_tensors == 4)
-    {
-        multi_tensor_apply<4>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<4, at::Half, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run,
-            wd_after_momentum,
-            scale);
-    }
-    // Case 4. fp32, fp32, fp32, Yes
-    else if (grad_type == at::ScalarType::Float &&
-             weight_type == at::ScalarType::Float &&
-             num_tensors == 4)
-    {
-        multi_tensor_apply<4>(
-            BLOCK_SIZE,
-            chunk_size,
-            noop_flag,
-            tensor_lists,
-            SGDFunctor<4, float, float>(),
-            wd,
-            momentum,
-            dampening,
-            lr,
-            nesterov,
-            first_run,
-            wd_after_momentum,
-            scale);
-    }
-    else
-    {
-        AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
-                 "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
-    }
+  // Case 1. fp16, fp16, fp16, No
+  if (grad_type == at::ScalarType::Half &&
+      weight_type == at::ScalarType::Half && num_tensors == 3) {
+    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                          SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
+                          dampening, lr, nesterov, first_run, wd_after_momentum,
+                          scale);
+  }
+  // Case 2. fp16, fp32, fp32, No
+  // else if (grad_type == at::ScalarType::Half &&
+  //          weight_type == at::ScalarType::Float &&
+  //          num_tensors == 3) {
+  //   multi_tensor_apply<3>(
+  //       BLOCK_SIZE,
+  //       chunk_size,
+  //       noop_flag,
+  //       tensor_lists,
+  //       SGDFunctor<3, at::Half, float>(),
+  //       wd,
+  //       momentum,
+  //       dampening,
+  //       lr,
+  //       nesterov,
+  //       first_run,
+  //       wd_after_momentum);
+  // }
+  // Case 2. fp32, fp32, fp32, No
+  else if (grad_type == at::ScalarType::Float &&
+           weight_type == at::ScalarType::Float && num_tensors == 3) {
+    multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                          SGDFunctor<3, float, float>(), wd, momentum,
+                          dampening, lr, nesterov, first_run, wd_after_momentum,
+                          scale);
+  }
+  // Case 3. fp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::Half &&
+           weight_type == at::ScalarType::Float && num_tensors == 4) {
+    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                          SGDFunctor<4, at::Half, float>(), wd, momentum,
+                          dampening, lr, nesterov, first_run, wd_after_momentum,
+                          scale);
+  }
+  // Case 4. fp32, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::Float &&
+           weight_type == at::ScalarType::Float && num_tensors == 4) {
+    multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                          SGDFunctor<4, float, float>(), wd, momentum,
+                          dampening, lr, nesterov, first_run, wd_after_momentum,
+                          scale);
+  } else {
+    AT_ERROR(
+        "multi_tensor_sgd only supports some combinations of gradient & weight "
+        "types. Given: ",
+        "gradient: ", grad_type, ", weight: ", weight_type,
+        ", num_lists: ", num_tensors);
+  }
 
-    AT_CUDA_CHECK(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
 }
\ No newline at end of file
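
Reviewer note (not part of the patch): the per-element rule that SGDFunctor applies is ordinary SGD with optional weight decay (before or after the momentum update), dampening, and optional Nesterov momentum, as described in the doc comment at the top of the functor. A minimal scalar sketch of that rule, with invented names and none of the CUDA chunking, might look like the following:

// Illustrative sketch only: mirrors the per-element logic of SGDFunctor.
// The struct and function names are invented for this example.
struct SgdState {
  float weight;
  float momentum_buf;
};

inline void sgd_step(SgdState &s, float grad, float wd, float momentum,
                     float dampening, float lr, bool nesterov, bool first_run,
                     bool wd_after_momentum, float scale) {
  float g = grad * scale;  // gradients may carry a loss-scaling factor

  // weight decay applied before the momentum update unless requested after
  if (wd != 0.f && !wd_after_momentum) g += wd * s.weight;

  if (momentum != 0.f) {
    if (!first_run)
      s.momentum_buf = s.momentum_buf * momentum + (1.f - dampening) * g;
    else  // first step: initialize the buffer with the incoming gradient
      s.momentum_buf = g;

    g = nesterov ? g + momentum * s.momentum_buf : s.momentum_buf;
  }

  // weight decay applied after the momentum update if requested
  if (wd != 0.f && wd_after_momentum) g += wd * s.weight;

  s.weight += -lr * g;  // final weight update
}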
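
Similarly, a hedged sketch of how the host entry point might be driven for case 3 (fp16 gradients, fp32 master weights and momentum, plus an fp16 weight copy that the kernel refreshes). Only the multi_tensor_sgd_cuda signature and the tl[0..3] layout come from the file itself; the helper name, the chunk-size value, and the tensor-list arguments are assumptions made up for the example:

// Hypothetical call site, not part of the patch.
#include <vector>
#include <ATen/ATen.h>

// declaration copied from the patched file
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

void run_sgd_case3(at::Tensor noop_flag,
                   std::vector<at::Tensor> fp16_grads,     // tl[0]: gradients
                   std::vector<at::Tensor> fp32_weights,   // tl[1]: master weights
                   std::vector<at::Tensor> fp32_momentum,  // tl[2]: momentum buffers
                   std::vector<at::Tensor> fp16_weights,   // tl[3]: fp16 copy written back
                   float lr, float momentum, float scale) {
  std::vector<std::vector<at::Tensor>> tensor_lists = {
      fp16_grads, fp32_weights, fp32_momentum, fp16_weights};
  multi_tensor_sgd_cuda(/*chunk_size=*/2048 * 32,  // assumed value for the example
                        noop_flag, tensor_lists,
                        /*wd=*/0.f, momentum, /*dampening=*/0.f, lr,
                        /*nesterov=*/false, /*first_run=*/true,
                        /*wd_after_momentum=*/false, scale);
}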