[NFC] polish colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu code style (#978)

This commit is contained in:
binmakeswell 2022-05-16 14:09:01 +08:00
parent 18542b47fc
commit f28c021376

View File

@ -1,14 +1,15 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu // modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"
#include <assert.h> #include <assert.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "compat.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
@ -28,69 +29,53 @@
* wd_after_momentum : apply weight decay _after_ momentum instead of before * wd_after_momentum : apply weight decay _after_ momentum instead of before
**/ **/
template <int N, typename T_grad, typename T_weight> template <int N, typename T_grad, typename T_weight>
struct SGDFunctor struct SGDFunctor {
{ __device__ __forceinline__ void operator()(
__device__ __forceinline__ void operator()( int chunk_size, volatile int *noop_gmem, TensorListMetadata<N> &tl,
int chunk_size, float wd, float momentum, float dampening, float lr, bool nesterov,
volatile int *noop_gmem, bool first_run, bool wd_after_momentum, float scale) {
TensorListMetadata<N> &tl, // Early exit if we don't need to do anything
float wd, if (*noop_gmem) return;
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
// Early exit if we don't need to do anything
if (*noop_gmem)
return;
int tensor_loc = tl.block_to_tensor[blockIdx.x]; int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x]; int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc]; int n = tl.sizes[tensor_loc];
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc]; T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
grad_in += chunk_idx * chunk_size; grad_in += chunk_idx * chunk_size;
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc]; T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
weight_in += chunk_idx * chunk_size; weight_in += chunk_idx * chunk_size;
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc]; T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
mom_in += chunk_idx * chunk_size; mom_in += chunk_idx * chunk_size;
at::Half *model_weights_out = nullptr; at::Half *model_weights_out = nullptr;
if (N == 4) if (N == 4) {
{ model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
model_weights_out = (at::Half *)tl.addresses[3][tensor_loc]; model_weights_out += chunk_idx * chunk_size;
model_weights_out += chunk_idx * chunk_size; }
}
n -= chunk_idx * chunk_size; n -= chunk_idx * chunk_size;
// Non-divergent exit condition for the __syncthreads // Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP]; float incoming_grads[ILP];
float incoming_weights[ILP]; float incoming_weights[ILP];
float incoming_moms[ILP]; float incoming_moms[ILP];
for (int i_start = 0; for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
i_start += blockDim.x * ILP)
{
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) for (int ii = 0; ii < ILP; ii++) {
{ incoming_grads[ii] = 0;
incoming_grads[ii] = 0; incoming_weights[ii] = 0;
incoming_weights[ii] = 0; incoming_moms[ii] = 0;
incoming_moms[ii] = 0; int i = i_start + threadIdx.x + ii * blockDim.x;
int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) {
if (i < n && i < chunk_size) incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
{ incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale; incoming_moms[ii] = static_cast<float>(mom_in[i]);
incoming_weights[ii] = static_cast<float>(weight_in[i]); }
incoming_moms[ii] = static_cast<float>(mom_in[i]); }
}
}
// note for clarification to future michael: // note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling // From a pure memory dependency perspective, there's likely no point unrolling
@ -98,185 +83,128 @@ struct SGDFunctor
// Put another way, the STGs are dependent on the LDGs, but not on each other. // Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though. // There is still compute ILP benefit from unrolling the loop though.
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) for (int ii = 0; ii < ILP; ii++) {
{ int i = i_start + threadIdx.x + ii * blockDim.x;
int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) {
if (i < n && i < chunk_size) // apply weight decay before momentum if necessary
{ if (wd != 0.f && !wd_after_momentum)
// apply weight decay before momentum if necessary incoming_grads[ii] += wd * incoming_weights[ii];
if (wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if (momentum != 0.f) if (momentum != 0.f) {
{ if (!first_run)
if (!first_run) incoming_moms[ii] = incoming_moms[ii] * momentum +
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii]; (1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii]; incoming_moms[ii] = incoming_grads[ii];
if (nesterov) if (nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii]; incoming_grads[ii] += momentum * incoming_moms[ii];
else else
incoming_grads[ii] = incoming_moms[ii]; incoming_grads[ii] = incoming_moms[ii];
} }
// Apply WD after momentum if desired // Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum) if (wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii]; incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out // adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]); weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights // if necessary, write out an fp16 copy of the weights
if (N == 4) if (N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]); model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum // also write out the new momentum
if (momentum != 0.f) if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
mom_in[i] = incoming_moms[ii];
}
}
} }
}
} }
}
}; };
void multi_tensor_sgd_cuda( void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
int chunk_size, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor noop_flag, float wd, float momentum, float dampening, float lr,
std::vector<std::vector<at::Tensor>> tensor_lists, bool nesterov, bool first_run,
float wd, bool wd_after_momentum, float scale) {
float momentum, auto num_tensors = tensor_lists.size();
float dampening, auto grad_type = tensor_lists[0][0].scalar_type();
float lr, auto weight_type = tensor_lists[1][0].scalar_type();
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if (num_tensors == 4) if (num_tensors == 4)
for (int i = 0; i < tensor_lists[3].size(); i++) for (int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16."); "Additional output tensors should always be fp16.");
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors"); TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of // We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy // grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No // 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No // 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes // 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use // It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where // switches etc. to handle the cross-product of cases where
// we don't want the majority of them. // we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No // Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half && weight_type == at::ScalarType::Half && num_tensors == 3) {
num_tensors == 3) multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
{ SGDFunctor<3, at::Half, at::Half>(), wd, momentum,
multi_tensor_apply<3>( dampening, lr, nesterov, first_run, wd_after_momentum,
BLOCK_SIZE, scale);
chunk_size, }
noop_flag, // Case 2. fp16, fp32, fp32, No
tensor_lists, // else if (grad_type == at::ScalarType::Half &&
SGDFunctor<3, at::Half, at::Half>(), // weight_type == at::ScalarType::Float &&
wd, // num_tensors == 3) {
momentum, // multi_tensor_apply<3>(
dampening, // BLOCK_SIZE,
lr, // chunk_size,
nesterov, // noop_flag,
first_run, // tensor_lists,
wd_after_momentum, // SGDFunctor<3, at::Half, float>(),
scale); // wd,
} // momentum,
// Case 2. fp16, fp32, fp32, No // dampening,
// else if (grad_type == at::ScalarType::Half && // lr,
// weight_type == at::ScalarType::Float && // nesterov,
// num_tensors == 3) { // first_run,
// multi_tensor_apply<3>( // wd_after_momentum);
// BLOCK_SIZE, // }
// chunk_size, // Case 2. fp32, fp32, fp32, No
// noop_flag, else if (grad_type == at::ScalarType::Float &&
// tensor_lists, weight_type == at::ScalarType::Float && num_tensors == 3) {
// SGDFunctor<3, at::Half, float>(), multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
// wd, SGDFunctor<3, float, float>(), wd, momentum,
// momentum, dampening, lr, nesterov, first_run, wd_after_momentum,
// dampening, scale);
// lr, }
// nesterov, // Case 3. fp16, fp32, fp32, Yes
// first_run, else if (grad_type == at::ScalarType::Half &&
// wd_after_momentum); weight_type == at::ScalarType::Float && num_tensors == 4) {
// } multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
// Case 2. fp32, fp32, fp32, No SGDFunctor<4, at::Half, float>(), wd, momentum,
else if (grad_type == at::ScalarType::Float && dampening, lr, nesterov, first_run, wd_after_momentum,
weight_type == at::ScalarType::Float && scale);
num_tensors == 3) }
{ // Case 4. fp32, fp32, fp32, Yes
multi_tensor_apply<3>( else if (grad_type == at::ScalarType::Float &&
BLOCK_SIZE, weight_type == at::ScalarType::Float && num_tensors == 4) {
chunk_size, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
noop_flag, SGDFunctor<4, float, float>(), wd, momentum,
tensor_lists, dampening, lr, nesterov, first_run, wd_after_momentum,
SGDFunctor<3, float, float>(), scale);
wd, } else {
momentum, AT_ERROR(
dampening, "multi_tensor_sgd only supports some combinations of gradient & weight "
lr, "types. Given: ",
nesterov, "gradient: ", grad_type, ", weight: ", weight_type,
first_run, ", num_lists: ", num_tensors);
wd_after_momentum, }
scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, at::Half, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }