mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
Migrated project
This commit is contained in:
71
csrc/colossal_C_frontend.cpp
Normal file
71
csrc/colossal_C_frontend.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
|
||||
#include <torch/extension.h>
|
||||
|
||||
void multi_tensor_scale_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float scale);
|
||||
|
||||
void multi_tensor_sgd_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float wd,
|
||||
float momentum,
|
||||
float dampening,
|
||||
float lr,
|
||||
bool nesterov,
|
||||
bool first_run,
|
||||
bool wd_after_momentum,
|
||||
float scale);
|
||||
|
||||
void multi_tensor_adam_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float epsilon,
|
||||
const int step,
|
||||
const int mode,
|
||||
const int bias_correction,
|
||||
const float weight_decay);
|
||||
|
||||
void multi_tensor_lamb_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float epsilon,
|
||||
const int step,
|
||||
const int bias_correction,
|
||||
const float weight_decay,
|
||||
const int grad_averaging,
|
||||
const int mode,
|
||||
at::Tensor global_grad_norm,
|
||||
const float max_grad_norm,
|
||||
at::optional<bool> use_nvlamb_python);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
|
||||
"Fused overflow check + scale for a list of contiguous tensors");
|
||||
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
|
||||
"Fused SGD optimizer for list of contiguous tensors");
|
||||
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
|
||||
"Compute and apply gradient update to parameters for Adam optimizer");
|
||||
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
|
||||
"Computes and apply update for LAMB optimizer");
|
||||
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
|
||||
"Computes L2 norm for a list of contiguous tensors");
|
||||
}
|
10
csrc/compat.h
Normal file
10
csrc/compat.h
Normal file
@@ -0,0 +1,10 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
|
||||
#ifndef TORCH_CHECK
|
||||
#define TORCH_CHECK AT_CHECK
|
||||
#endif
|
||||
|
||||
#ifdef VERSION_GE_1_3
|
||||
#define DATA_PTR data_ptr
|
||||
#else
|
||||
#define DATA_PTR data
|
||||
#endif
|
177
csrc/multi_tensor_adam.cu
Normal file
177
csrc/multi_tensor_adam.cu
Normal file
@@ -0,0 +1,177 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "type_shim.h"
|
||||
#include "multi_tensor_apply.cuh"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
typedef enum
|
||||
{
|
||||
ADAM_MODE_0 = 0, // L2 regularization mode
|
||||
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
|
||||
} adamMode_t;
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
struct AdamFunctor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<4> &tl,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float beta1_correction,
|
||||
const float beta2_correction,
|
||||
const float epsilon,
|
||||
const float lr,
|
||||
adamMode_t mode,
|
||||
const float decay)
|
||||
{
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
|
||||
// potentially use to pass in list of scalar
|
||||
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T *g = (T *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T *m = (T *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T *v = (T *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0;
|
||||
i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP)
|
||||
{
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
r_g[ii] = g[i];
|
||||
r_p[ii] = p[i];
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
if (mode == ADAM_MODE_0)
|
||||
{ // L2
|
||||
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
MATH_T update = next_m_unbiased / denom;
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
}
|
||||
else
|
||||
{ // weight decay
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
p[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_adam_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float epsilon,
|
||||
const int step,
|
||||
const int mode,
|
||||
const int bias_correction,
|
||||
const float weight_decay)
|
||||
{
|
||||
using namespace at;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1)
|
||||
{
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
// Assume single type across p,g,m1,m2 now
|
||||
DISPATCH_DOUBLE_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "adam",
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
AdamFunctor<scalar_t_0>(),
|
||||
beta1,
|
||||
beta2,
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
epsilon,
|
||||
lr,
|
||||
(adamMode_t)mode,
|
||||
weight_decay);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
133
csrc/multi_tensor_apply.cuh
Normal file
133
csrc/multi_tensor_apply.cuh
Normal file
@@ -0,0 +1,133 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "compat.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
// #include <iostream>
|
||||
|
||||
// This header is the one-stop shop for all your multi-tensor apply needs.
|
||||
|
||||
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
|
||||
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
||||
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
||||
|
||||
template <int n>
|
||||
struct TensorListMetadata
|
||||
{
|
||||
void *addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int sizes[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
__global__ void multi_tensor_apply_kernel(
|
||||
int chunk_size,
|
||||
volatile int *noop_flag,
|
||||
T tl,
|
||||
U callable,
|
||||
ArgTypes... args)
|
||||
{
|
||||
// Hand the chunk information to the user-supplied functor to process however it likes.
|
||||
callable(chunk_size, noop_flag, tl, args...);
|
||||
}
|
||||
|
||||
template <int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(
|
||||
int block_size,
|
||||
int chunk_size,
|
||||
const at::Tensor &noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>> &tensor_lists,
|
||||
T callable,
|
||||
ArgTypes... args)
|
||||
{
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
|
||||
int len0 = tensor_lists[0].size();
|
||||
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
|
||||
auto ref_device = tensor_lists[0][0].device();
|
||||
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
|
||||
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
|
||||
{
|
||||
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
|
||||
for (int t = 0; t < tensor_lists[l].size(); t++)
|
||||
{
|
||||
// TODO: Print which tensor fails.
|
||||
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
|
||||
#ifdef VERSION_GE_1_5
|
||||
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
|
||||
#endif
|
||||
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
|
||||
TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
|
||||
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
|
||||
TensorListMetadata<depth> tl;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tl.start_tensor_this_launch = 0;
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (int t = 0; t < ntensors; t++)
|
||||
{
|
||||
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++)
|
||||
{
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
loc_block_info++;
|
||||
|
||||
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks_this_tensor - 1);
|
||||
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
|
||||
if (tensors_full || blocks_full || last_chunk)
|
||||
{
|
||||
// using accscalar_t = acc_type<scalar_t, true>;
|
||||
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
|
||||
chunk_size,
|
||||
noop_flag.DATA_PTR<int>(),
|
||||
tl,
|
||||
callable,
|
||||
args...);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Reset. The control flow possibilities here make my brain hurt.
|
||||
loc_block_info = 0;
|
||||
if (chunk == chunks_this_tensor - 1)
|
||||
{
|
||||
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
|
||||
loc_tensor_info = 0;
|
||||
tl.start_tensor_this_launch = t + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
|
||||
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
|
||||
loc_tensor_info = 1;
|
||||
tl.start_tensor_this_launch = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
455
csrc/multi_tensor_l2norm_kernel.cu
Normal file
455
csrc/multi_tensor_l2norm_kernel.cu
Normal file
@@ -0,0 +1,455 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "type_shim.h"
|
||||
#include "multi_tensor_apply.cuh"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p)
|
||||
{
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset)
|
||||
{
|
||||
typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename x_t>
|
||||
struct L2NormFunctor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<1> &tl,
|
||||
float *output,
|
||||
float *output_per_tensor,
|
||||
bool per_tensor,
|
||||
int max_chunks_per_tensor)
|
||||
{
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
|
||||
x += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++)
|
||||
{
|
||||
vals[i] = 0.f;
|
||||
r_x[i] = 0;
|
||||
}
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
|
||||
{
|
||||
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
|
||||
{
|
||||
// load
|
||||
load_store(r_x, x, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
float next = static_cast<float>(r_x[ii]);
|
||||
vals[ii] += next * next;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
float next = static_cast<float>(x[i]);
|
||||
vals[ii] += next * next;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++)
|
||||
val += vals[i];
|
||||
|
||||
float final = reduce_block_into_lanes(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
if (!isfinite(final))
|
||||
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] += final;
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Probably better to template, but since we are not likely to support other norm
|
||||
template <typename x_t>
|
||||
struct MaxNormFunctor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<1> &tl,
|
||||
float *output,
|
||||
float *output_per_tensor,
|
||||
bool per_tensor,
|
||||
int max_chunks_per_tensor)
|
||||
{
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
|
||||
x += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++)
|
||||
{
|
||||
vals[i] = 0.f;
|
||||
r_x[i] = 0;
|
||||
}
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
|
||||
{
|
||||
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
|
||||
{
|
||||
// load
|
||||
load_store(r_x, x, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
float next = static_cast<float>(r_x[ii]);
|
||||
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
float next = static_cast<float>(x[i]);
|
||||
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++)
|
||||
val = fmaxf(fabsf(val), fabsf(vals[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
if (!isfinite(final))
|
||||
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
__global__ void cleanup(
|
||||
float *output,
|
||||
float *output_per_tensor,
|
||||
float *ret,
|
||||
float *ret_per_tensor,
|
||||
bool per_tensor,
|
||||
int max_chunks_per_tensor)
|
||||
{
|
||||
__shared__ float vals[512];
|
||||
|
||||
if (blockIdx.x == 0)
|
||||
{
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320)
|
||||
val = output[threadIdx.x];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
*ret = sqrt(final);
|
||||
}
|
||||
|
||||
if (per_tensor)
|
||||
{
|
||||
float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val += output_this_tensor[i];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] = sqrt(final);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void cleanup_v2(
|
||||
float *output,
|
||||
float *output_per_tensor,
|
||||
float *ret,
|
||||
float *ret_per_tensor,
|
||||
bool per_tensor,
|
||||
int max_chunks_per_tensor,
|
||||
int norm_type,
|
||||
float alpha,
|
||||
float beta)
|
||||
{
|
||||
__shared__ float vals[512];
|
||||
|
||||
if (blockIdx.x == 0)
|
||||
{
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320)
|
||||
val = output[threadIdx.x];
|
||||
|
||||
if (norm_type == 0)
|
||||
{
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
if (threadIdx.x == 0)
|
||||
*ret = alpha * (*ret) + beta * final;
|
||||
}
|
||||
else
|
||||
{
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
if (threadIdx.x == 0)
|
||||
*ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
|
||||
}
|
||||
}
|
||||
|
||||
if (per_tensor)
|
||||
{
|
||||
float *output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||
|
||||
if (norm_type == 0)
|
||||
{
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final;
|
||||
}
|
||||
else
|
||||
{
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val += output_this_tensor[i];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python)
|
||||
{
|
||||
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
|
||||
|
||||
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
|
||||
auto output = at::zeros({320}, float_options);
|
||||
|
||||
at::Tensor output_per_tensor;
|
||||
at::Tensor ret_per_tensor;
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
int max_chunks_per_tensor = -1;
|
||||
|
||||
if (per_tensor)
|
||||
{
|
||||
for (int t = 0; t < ntensors; t++)
|
||||
{
|
||||
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
if (max_chunks_this_tensor > max_chunks_per_tensor)
|
||||
max_chunks_per_tensor = max_chunks_this_tensor;
|
||||
}
|
||||
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
|
||||
ret_per_tensor = at::empty({ntensors}, float_options);
|
||||
}
|
||||
else
|
||||
{
|
||||
ret_per_tensor = at::empty({0}, float_options);
|
||||
}
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
L2NormFunctor<scalar_t_0>(),
|
||||
output.DATA_PTR<float>(),
|
||||
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
per_tensor,
|
||||
max_chunks_per_tensor);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// This involves one more small kernel launches, but will be negligible end to end.
|
||||
// I could get rid of these by hacking the functor + multi tensor harness with persistence
|
||||
// logic, but keeping it simple for now
|
||||
auto ret = at::empty({1}, output.options());
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
|
||||
output.DATA_PTR<float>(),
|
||||
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
ret.DATA_PTR<float>(),
|
||||
per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
per_tensor,
|
||||
max_chunks_per_tensor);
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
|
||||
}
|
||||
|
||||
// Compute and update grad norm
|
||||
// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by
|
||||
// L-2: gn = sqrt(a * gn^2 + b * n^2)
|
||||
// L-inf: gn = a * gn + b * n
|
||||
void multi_tensor_norm_out_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::Tensor out,
|
||||
const float alpha,
|
||||
const float beta,
|
||||
const int norm_type)
|
||||
{
|
||||
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
|
||||
TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
|
||||
// we don't need global thus uses empty here
|
||||
auto output = at::empty({320}, float_options);
|
||||
|
||||
at::Tensor output_per_tensor;
|
||||
at::Tensor ret_per_tensor;
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
int max_chunks_per_tensor = -1;
|
||||
|
||||
for (int t = 0; t < ntensors; t++)
|
||||
{
|
||||
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
if (max_chunks_this_tensor > max_chunks_per_tensor)
|
||||
max_chunks_per_tensor = max_chunks_this_tensor;
|
||||
}
|
||||
|
||||
// Although it is single write then read, still need to be zero
|
||||
// Since tailing element also participate cleanup
|
||||
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
|
||||
|
||||
if (norm_type == 0)
|
||||
{
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
MaxNormFunctor<scalar_t_0>(),
|
||||
output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(),
|
||||
true,
|
||||
max_chunks_per_tensor);)
|
||||
}
|
||||
else
|
||||
{
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
L2NormFunctor<scalar_t_0>(),
|
||||
output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(),
|
||||
true,
|
||||
max_chunks_per_tensor);)
|
||||
}
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// This involves one more small kernel launches, but will be negligible end to end.
|
||||
// I could get rid of these by hacking the functor + multi tensor harness with persistence
|
||||
// logic, but keeping it simple for now
|
||||
auto ret = at::empty({1}, output.options());
|
||||
|
||||
// Adding the following device guard since it happens sometimes that the
|
||||
// tensors are on one device and the cuda stream is on another device which
|
||||
// results in ILLEGAL MEM ACCESS error.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cleanup_v2<<<ntensors, 512, 0, stream>>>(
|
||||
output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(),
|
||||
ret.DATA_PTR<float>(),
|
||||
out.DATA_PTR<float>(),
|
||||
true,
|
||||
max_chunks_per_tensor,
|
||||
norm_type,
|
||||
alpha,
|
||||
beta);
|
||||
|
||||
return;
|
||||
}
|
427
csrc/multi_tensor_lamb.cu
Normal file
427
csrc/multi_tensor_lamb.cu
Normal file
@@ -0,0 +1,427 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "type_shim.h"
|
||||
#include "multi_tensor_apply.cuh"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p)
|
||||
{
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int src_offset)
|
||||
{
|
||||
typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
typedef enum
|
||||
{
|
||||
MOMENT_MODE_0 = 0, // L2 regularization mode
|
||||
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
||||
} adamMode_t;
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python);
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
struct LAMBStage1Functor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<4> &tl,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float beta3,
|
||||
const float beta1_correction,
|
||||
const float beta2_correction,
|
||||
const float epsilon,
|
||||
adamMode_t mode,
|
||||
const float decay,
|
||||
const float *global_grad_norm,
|
||||
const float max_global_grad_norm)
|
||||
{
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
|
||||
|
||||
T *g = (T *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T *m = (T *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T *v = (T *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 &&
|
||||
chunk_size % ILP == 0 &&
|
||||
is_aligned(g) &&
|
||||
is_aligned(p) &&
|
||||
is_aligned(m) &&
|
||||
is_aligned(v))
|
||||
{
|
||||
T l_g[ILP];
|
||||
T l_p[ILP];
|
||||
T l_m[ILP];
|
||||
T l_v[ILP];
|
||||
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
|
||||
{
|
||||
// load
|
||||
load_store(l_g, g, 0, i_start);
|
||||
if (decay != 0)
|
||||
load_store(l_p, p, 0, i_start);
|
||||
load_store(l_m, m, 0, i_start);
|
||||
load_store(l_v, v, 0, i_start);
|
||||
// unpack
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
r_g[ii] = l_g[ii];
|
||||
if (decay == 0)
|
||||
{
|
||||
r_p[ii] = MATH_T(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
r_p[ii] = l_p[ii];
|
||||
}
|
||||
r_m[ii] = l_m[ii];
|
||||
r_v[ii] = l_v[ii];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
if (mode == MOMENT_MODE_0)
|
||||
{
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
// L2 on scaled grad
|
||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = next_m_unbiased / denom;
|
||||
}
|
||||
else
|
||||
{
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
l_p[ii] = r_p[ii];
|
||||
l_m[ii] = r_m[ii];
|
||||
l_v[ii] = r_v[ii];
|
||||
}
|
||||
// store
|
||||
load_store(g, l_p, i_start, 0);
|
||||
load_store(m, l_m, i_start, 0);
|
||||
load_store(v, l_v, i_start, 0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0;
|
||||
i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP)
|
||||
{
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
r_g[ii] = g[i];
|
||||
// special ?optimization? for lamb stage 1
|
||||
if (decay == 0)
|
||||
{
|
||||
r_p[ii] = MATH_T(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
r_p[ii] = p[i];
|
||||
}
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
if (mode == MOMENT_MODE_0)
|
||||
{
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
// L2 on scaled grad
|
||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = next_m_unbiased / denom;
|
||||
}
|
||||
else
|
||||
{
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
g[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
|
||||
// It computes new parameter value.
|
||||
template <typename T>
|
||||
struct LAMBStage2Functor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<2> &tl,
|
||||
const float *per_tensor_param_norm,
|
||||
const float *per_tensor_update_norm,
|
||||
const float learning_rate,
|
||||
const float decay,
|
||||
bool use_nvlamb)
|
||||
{
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
MATH_T ratio = learning_rate;
|
||||
// nvlamb: apply adaptive learning rate to all parameters
|
||||
// otherwise, only apply to those with non-zero weight decay
|
||||
if (use_nvlamb || (decay != 0.0))
|
||||
{
|
||||
float param_norm = per_tensor_param_norm[tensor_num];
|
||||
float update_norm = per_tensor_update_norm[tensor_num];
|
||||
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
|
||||
}
|
||||
|
||||
T *update = (T *)tl.addresses[0][tensor_loc];
|
||||
update += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 &&
|
||||
chunk_size % ILP == 0 &&
|
||||
is_aligned(p) &&
|
||||
is_aligned(update))
|
||||
{
|
||||
T r_p[ILP];
|
||||
T r_update[ILP];
|
||||
for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x)
|
||||
{
|
||||
// load
|
||||
load_store(r_p, p, 0, i_start);
|
||||
load_store(r_update, update, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
|
||||
}
|
||||
load_store(p, r_p, i_start, 0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int i_start = 0;
|
||||
i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP)
|
||||
{
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_update[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
r_p[ii] = p[i];
|
||||
r_update[ii] = update[i];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
p[i] = r_p[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_lamb_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr,
|
||||
const float beta1,
|
||||
const float beta2,
|
||||
const float epsilon,
|
||||
const int step,
|
||||
const int bias_correction,
|
||||
const float weight_decay,
|
||||
const int grad_averaging,
|
||||
const int mode,
|
||||
at::Tensor global_grad_norm,
|
||||
const float max_grad_norm,
|
||||
at::optional<bool> use_nvlamb_python)
|
||||
{
|
||||
using namespace at;
|
||||
// Master weight and 32bit momentum(potentially changing) is not handled by this
|
||||
// So we assume every tensor are all in the same type
|
||||
|
||||
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1)
|
||||
{
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
// Handle grad averaging mode
|
||||
float beta3 = 1.0f;
|
||||
if (grad_averaging == 1)
|
||||
beta3 = 1 - beta1;
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);
|
||||
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1, tensor_lists.begin() + 2);
|
||||
|
||||
// Compute per tensor param norm
|
||||
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
|
||||
|
||||
// We now in-place modify grad to store update before compute its norm
|
||||
// Generally this is not a issue since people modify grad in step() method all the time
|
||||
// We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
|
||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
LAMBStage1Functor<scalar_t_0>(),
|
||||
beta1,
|
||||
beta2,
|
||||
beta3, // 1-beta1 or 1 depends on averaging mode
|
||||
bias_correction1,
|
||||
bias_correction2,
|
||||
epsilon,
|
||||
(adamMode_t)mode,
|
||||
weight_decay,
|
||||
global_grad_norm.DATA_PTR<float>(),
|
||||
max_grad_norm);)
|
||||
|
||||
// Compute update norms
|
||||
auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin() + 2);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
|
||||
multi_tensor_apply<2>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
grad_param_list,
|
||||
LAMBStage2Functor<scalar_t_0>(),
|
||||
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
|
||||
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
|
||||
lr,
|
||||
weight_decay,
|
||||
use_nvlamb);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
136
csrc/multi_tensor_scale_kernel.cu
Normal file
136
csrc/multi_tensor_scale_kernel.cu
Normal file
@@ -0,0 +1,136 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
|
||||
#include <sstream>
|
||||
|
||||
#include "type_shim.h"
|
||||
#include "multi_tensor_apply.cuh"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T* p){
|
||||
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
|
||||
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
|
||||
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
|
||||
}
|
||||
|
||||
template<typename in_t, typename out_t>
|
||||
struct ScaleFunctor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int* noop_gmem,
|
||||
TensorListMetadata<2>& tl,
|
||||
float scale)
|
||||
{
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
in_t* in = (in_t*)tl.addresses[0][tensor_loc];
|
||||
in += chunk_idx*chunk_size;
|
||||
|
||||
out_t* out = (out_t*)tl.addresses[1][tensor_loc];
|
||||
out += chunk_idx*chunk_size;
|
||||
|
||||
n -= chunk_idx*chunk_size;
|
||||
|
||||
bool finite = true;
|
||||
in_t r_in[ILP];
|
||||
out_t r_out[ILP];
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out))
|
||||
{
|
||||
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
|
||||
{
|
||||
// load
|
||||
load_store(r_in, in, 0 , i_start);
|
||||
#pragma unroll
|
||||
for(int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
|
||||
finite = finite && isfinite(r_in[ii]);
|
||||
}
|
||||
// store
|
||||
load_store(out, r_out, i_start, 0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Non-divergent exit condition for __syncthreads, not necessary here
|
||||
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
r_in[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii*blockDim.x;
|
||||
if(i < n && i < chunk_size)
|
||||
r_in[ii] = in[i];
|
||||
}
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point unrolling
|
||||
// the write loop, since writes just fire off once their LDGs arrive.
|
||||
// Put another way, the STGs are dependent on the LDGs, but not on each other.
|
||||
// There is still compute ILP benefit from unrolling the loop though.
|
||||
#pragma unroll
|
||||
for(int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
|
||||
finite = finite && isfinite(r_in[ii]);
|
||||
}
|
||||
#pragma unroll
|
||||
for(int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii*blockDim.x;
|
||||
if(i < n && i < chunk_size)
|
||||
out[i] = r_out[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
if(!finite)
|
||||
*noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_scale_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float scale)
|
||||
{
|
||||
using namespace at;
|
||||
// The output (downscaled) type is always float.
|
||||
// If build times suffer, think about where to put this dispatch,
|
||||
// and what logic should be moved out of multi_tensor_apply.
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
|
||||
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
|
||||
multi_tensor_apply<2>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
ScaleFunctor<scalar_t_0, scalar_t_1>(),
|
||||
scale); ))
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
282
csrc/multi_tensor_sgd_kernel.cu
Normal file
282
csrc/multi_tensor_sgd_kernel.cu
Normal file
@@ -0,0 +1,282 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "compat.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
/**
|
||||
* Perform fused SGD on multiple buffers
|
||||
* N: number of tensors
|
||||
* tl[0] : gradients
|
||||
* tl[1] : weights
|
||||
* tl[2] : momentum buffers
|
||||
* tl[3] : fp16 weights (if appropriate)
|
||||
* wd : weight_decay (scalar)
|
||||
* momentum : momentum (scalar)
|
||||
* dampening : momentum dampening (scalar)
|
||||
* lr : learning rate (scalar)
|
||||
* nesterov : enable nesterov (bool)
|
||||
* first run : necessary for proper momentum handling & init
|
||||
* wd_after_momentum : apply weight decay _after_ momentum instead of before
|
||||
**/
|
||||
template <int N, typename T_grad, typename T_weight>
|
||||
struct SGDFunctor
|
||||
{
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<N> &tl,
|
||||
float wd,
|
||||
float momentum,
|
||||
float dampening,
|
||||
float lr,
|
||||
bool nesterov,
|
||||
bool first_run,
|
||||
bool wd_after_momentum,
|
||||
float scale)
|
||||
{
|
||||
// Early exit if we don't need to do anything
|
||||
if (*noop_gmem)
|
||||
return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
|
||||
grad_in += chunk_idx * chunk_size;
|
||||
|
||||
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
|
||||
weight_in += chunk_idx * chunk_size;
|
||||
|
||||
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
|
||||
mom_in += chunk_idx * chunk_size;
|
||||
|
||||
at::Half *model_weights_out = nullptr;
|
||||
if (N == 4)
|
||||
{
|
||||
model_weights_out = (at::Half *)tl.addresses[3][tensor_loc];
|
||||
model_weights_out += chunk_idx * chunk_size;
|
||||
}
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// Non-divergent exit condition for the __syncthreads
|
||||
float incoming_grads[ILP];
|
||||
float incoming_weights[ILP];
|
||||
float incoming_moms[ILP];
|
||||
for (int i_start = 0;
|
||||
i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
incoming_grads[ii] = 0;
|
||||
incoming_weights[ii] = 0;
|
||||
incoming_moms[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
|
||||
incoming_weights[ii] = static_cast<float>(weight_in[i]);
|
||||
incoming_moms[ii] = static_cast<float>(mom_in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point unrolling
|
||||
// the write loop, since writes just fire off once their LDGs arrive.
|
||||
// Put another way, the STGs are dependent on the LDGs, but not on each other.
|
||||
// There is still compute ILP benefit from unrolling the loop though.
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++)
|
||||
{
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size)
|
||||
{
|
||||
// apply weight decay before momentum if necessary
|
||||
if (wd != 0.f && !wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
if (momentum != 0.f)
|
||||
{
|
||||
if (!first_run)
|
||||
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
|
||||
else // initialize momentums to current incoming grads
|
||||
incoming_moms[ii] = incoming_grads[ii];
|
||||
|
||||
if (nesterov)
|
||||
incoming_grads[ii] += momentum * incoming_moms[ii];
|
||||
else
|
||||
incoming_grads[ii] = incoming_moms[ii];
|
||||
}
|
||||
|
||||
// Apply WD after momentum if desired
|
||||
if (wd != 0.f && wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
// adjust the weight and write out
|
||||
weight_in[i] += (-lr * incoming_grads[ii]);
|
||||
|
||||
// if necessary, write out an fp16 copy of the weights
|
||||
if (N == 4)
|
||||
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
|
||||
|
||||
// also write out the new momentum
|
||||
if (momentum != 0.f)
|
||||
mom_in[i] = incoming_moms[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_sgd_cuda(
|
||||
int chunk_size,
|
||||
at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float wd,
|
||||
float momentum,
|
||||
float dampening,
|
||||
float lr,
|
||||
bool nesterov,
|
||||
bool first_run,
|
||||
bool wd_after_momentum,
|
||||
float scale)
|
||||
{
|
||||
auto num_tensors = tensor_lists.size();
|
||||
auto grad_type = tensor_lists[0][0].scalar_type();
|
||||
auto weight_type = tensor_lists[1][0].scalar_type();
|
||||
|
||||
if (num_tensors == 4)
|
||||
for (int i = 0; i < tensor_lists[3].size(); i++)
|
||||
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
|
||||
"Additional output tensors should always be fp16.");
|
||||
|
||||
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
|
||||
|
||||
// We have 3 possibilities to handle here, in terms of
|
||||
// grad_type, param_type, momentum_type, requires_fp16_copy
|
||||
// 1. fp16, fp16, fp16, No
|
||||
// 2. fp32, fp32, fp32, No
|
||||
// 3. fp16, fp32, fp32, Yes
|
||||
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
|
||||
// It's easier to hardcode these possibilities than to use
|
||||
// switches etc. to handle the cross-product of cases where
|
||||
// we don't want the majority of them.
|
||||
|
||||
// Case 1. fp16, fp16, fp16, No
|
||||
if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Half &&
|
||||
num_tensors == 3)
|
||||
{
|
||||
multi_tensor_apply<3>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<3, at::Half, at::Half>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 2. fp16, fp32, fp32, No
|
||||
// else if (grad_type == at::ScalarType::Half &&
|
||||
// weight_type == at::ScalarType::Float &&
|
||||
// num_tensors == 3) {
|
||||
// multi_tensor_apply<3>(
|
||||
// BLOCK_SIZE,
|
||||
// chunk_size,
|
||||
// noop_flag,
|
||||
// tensor_lists,
|
||||
// SGDFunctor<3, at::Half, float>(),
|
||||
// wd,
|
||||
// momentum,
|
||||
// dampening,
|
||||
// lr,
|
||||
// nesterov,
|
||||
// first_run,
|
||||
// wd_after_momentum);
|
||||
// }
|
||||
// Case 2. fp32, fp32, fp32, No
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float &&
|
||||
num_tensors == 3)
|
||||
{
|
||||
multi_tensor_apply<3>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<3, float, float>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 3. fp16, fp32, fp32, Yes
|
||||
else if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Float &&
|
||||
num_tensors == 4)
|
||||
{
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<4, at::Half, float>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 4. fp32, fp32, fp32, Yes
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float &&
|
||||
num_tensors == 4)
|
||||
{
|
||||
multi_tensor_apply<4>(
|
||||
BLOCK_SIZE,
|
||||
chunk_size,
|
||||
noop_flag,
|
||||
tensor_lists,
|
||||
SGDFunctor<4, float, float>(),
|
||||
wd,
|
||||
momentum,
|
||||
dampening,
|
||||
lr,
|
||||
nesterov,
|
||||
first_run,
|
||||
wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
|
||||
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
202
csrc/type_shim.h
Normal file
202
csrc/type_shim.h
Normal file
@@ -0,0 +1,202 @@
|
||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
|
||||
#include <ATen/ATen.h>
|
||||
#include "compat.h"
|
||||
|
||||
// Forward/backward compatiblity hack around
|
||||
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
|
||||
// pending more future-proof guidance from upstream.
|
||||
// struct TypeShim
|
||||
// {
|
||||
// const at::Type& payload;
|
||||
// TypeShim(const at::Type& type) : payload(type) {}
|
||||
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
|
||||
// operator const at::Type&(){ return payload; };
|
||||
// // Enable dispatch switch statements to take *this directly for post-3aeb78
|
||||
// //operator at::ScalarType(){ return payload.; };
|
||||
// };
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Byte: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = uint8_t; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Double: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Double: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = double; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
|
||||
T val,
|
||||
int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
{
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
|
||||
if (blockSize >= 64)
|
||||
{
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
|
||||
{
|
||||
if (tid < i)
|
||||
x[tid] = x[tid] + x[tid + i];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
T final;
|
||||
|
||||
if (tid < 32)
|
||||
{
|
||||
if (blockSize >= 64)
|
||||
final = x[tid] + x[tid + 32];
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||
}
|
||||
|
||||
if (share_result)
|
||||
{
|
||||
if (tid < lanes)
|
||||
x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
return final;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
|
||||
T val,
|
||||
int lanes = 1,
|
||||
bool share_result = false) // lanes is intended to be <= 32.
|
||||
{
|
||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||
|
||||
if (blockSize >= 64)
|
||||
{
|
||||
x[tid] = val;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
|
||||
{
|
||||
if (tid < i)
|
||||
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
T final;
|
||||
|
||||
if (tid < 32)
|
||||
{
|
||||
if (blockSize >= 64)
|
||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||
else
|
||||
final = val;
|
||||
// __SYNCWARP();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 16; i >= lanes; i >>= 1)
|
||||
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||
}
|
||||
|
||||
if (share_result)
|
||||
{
|
||||
if (tid < lanes)
|
||||
x[tid] = final; // EpilogueOp
|
||||
// Make sure the smem result is visible to all warps.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
return final;
|
||||
}
|
Reference in New Issue
Block a user