Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)
[feat] refactored extension module (#5298)
* [feat] refactored extension module
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
extensions/csrc/cuda/colossal_C_frontend.cpp (new file, 49 lines added)
@@ -0,0 +1,49 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int mode,
                            const int bias_correction, const float weight_decay,
                            const float div_scale);

void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int bias_correction,
                            const float weight_decay, const int grad_averaging,
                            const int mode, at::Tensor global_grad_norm,
                            const float max_grad_norm,
                            at::optional<bool> use_nvlamb_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and apply update for LAMB optimizer");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
}
extensions/csrc/cuda/compat.h (new file, 10 lines added)
@@ -0,0 +1,10 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
extensions/csrc/cuda/cpu_adam.cpp (new file, 446 lines added)
@@ -0,0 +1,446 @@
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"

#include <math.h>
#include <omp.h>
#include <string.h>

#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>

// C++ interface

void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      AVX_Data grad_4;
      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
      if (loss_scale > 0) {
        AVX_Data loss_scale_vec;
        loss_scale_vec.data = SIMD_SET(loss_scale);
        grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
      }
      AVX_Data momentum_4;
      this->simd_load(momentum_half_precision, _exp_avg + i,
                      momentum_cast_h + i, momentum_4);

      AVX_Data variance_4;
      this->simd_load(variance_half_precision, _exp_avg_sq + i,
                      variance_cast_h + i, variance_4);

      AVX_Data param_4;
      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
                      param_4);

      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
      }
      momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
      momentum_4.data =
          SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
      variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
      grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
      variance_4.data =
          SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
      grad_4.data = SIMD_SQRT(variance_4.data);
      grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
      grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);

      if (_weight_decay > 0 && _adamw_mode) {
        param_4.data =
            SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
      }
      param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
                       param_4);
      this->simd_store(momentum_half_precision, _exp_avg + i,
                       momentum_cast_h + i, momentum_4);
      this->simd_store(variance_half_precision, _exp_avg_sq + i,
                       variance_cast_h + i, variance_4);
    }
  }
#endif
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;

#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param =
            param_half_precision ? (float)params_cast_h[k] : _params[k];
        float momentum =
            momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
        float variance = variance_half_precision ? (float)variance_cast_h[k]
                                                 : _exp_avg_sq[k];
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;

        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;

        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;

        if (param_half_precision)
          params_cast_h[k] = (__half)param;
        else
          _params[k] = param;
        if (momentum_half_precision)
          momentum_cast_h[k] = (__half)(momentum);
        else
          _exp_avg[k] = momentum;
        if (variance_half_precision)
          variance_cast_h[k] = (__half)(variance);
        else
          _exp_avg_sq[k] = variance;
      }
    }
  }
}

void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);

  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      AVX_Data grad_4[4];
      AVX_Data momentum_4[4];
      AVX_Data variance_4[4];
      AVX_Data param_4[4];
#pragma unroll 4
      for (int j = 0; j < 4; j++) {
        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);

        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_load(variance_half_precision,
                        _exp_avg_sq + i + SIMD_WIDTH * j,
                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);

        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);

        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_store(variance_half_precision,
                         _exp_avg_sq + i + SIMD_WIDTH * j,
                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
                                    : _exp_avg + rounded_size),
           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
                                    : _exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, momentum_half_precision,
           variance_half_precision, loss_scale);
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            bool momentum_half_precision,
                            bool variance_half_precision, float loss_scale) {
  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
  __half *params_cast_h = reinterpret_cast<__half *>(_params);
  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      AVX_Data grad_4[8];
      AVX_Data momentum_4[8];
      AVX_Data variance_4[8];
      AVX_Data param_4[8];
#pragma unroll 8
      for (int j = 0; j < 8; j++) {
        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);

        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_load(variance_half_precision,
                        _exp_avg_sq + i + SIMD_WIDTH * j,
                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);

        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);

        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
        this->simd_store(variance_half_precision,
                         _exp_avg_sq + i + SIMD_WIDTH * j,
                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
                                    : _exp_avg + rounded_size),
           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
                                    : _exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, momentum_half_precision,
           variance_half_precision, loss_scale);
}

void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
                          float epsilon, float weight_decay,
                          bool bias_correction, torch::Tensor &params,
                          torch::Tensor &grads, torch::Tensor &exp_avg,
                          torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  float *params_ptr = (float *)params_c.data_ptr();
  float *grads_ptr = (float *)grads_c.data_ptr();
  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();

  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
               params_c.numel(), (params.options().dtype() == at::kHalf),
               (grads.options().dtype() == at::kHalf),
               (exp_avg.options().dtype() == at::kHalf),
               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &Adam_Optimizer::step);
}
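For reference, the per-element update that Step_1, Step_4 and Step_8 above implement (in both the SIMD body and the scalar tail) is the standard Adam/AdamW rule, with the bias corrections folded into step_size and _bias_correction2:

$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g^2,$
$\theta_t = \theta_{t-1} - \frac{\alpha}{1-\beta_1^t} \cdot \frac{m_t}{\sqrt{v_t}/\sqrt{1-\beta_2^t} + \epsilon},$

where $g$ is the gradient after the optional division by loss_scale. In L2-regularization mode the term $\lambda\,\theta_{t-1}$ is added to $g$ first; in AdamW mode the decoupled decay $\theta \leftarrow \theta - \alpha\lambda\,\theta$ is applied separately (w_decay = -alpha * weight_decay), with $\lambda$ the weight_decay argument.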
extensions/csrc/cuda/cpu_adam.h (new file, 185 lines added)
@@ -0,0 +1,185 @@
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#pragma once

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)

#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d)                                         \
  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
                                     d, _MM_FROUND_TO_NEAREST_INT)))

#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d)                                   \
  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
                                  d, _MM_FROUND_TO_NEAREST_INT)))

#endif

union AVX_Data {
#if defined(__AVX512__)
  __m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
  __m256 data;
#endif
  // float data_f[16];
};

#endif

#define STEP(SPAN)                                                            \
  void Step_##SPAN(                                                           \
      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,      \
      size_t _param_size, bool param_half_precision = false,                  \
      bool grad_half_precision = false, bool momentum_half_precision = false, \
      bool variance_half_precision = false, float loss_scale = -1);

class Adam_Optimizer {
 public:
  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                 float eps = 1e-8, float weight_decay = 0,
                 bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~Adam_Optimizer() {}

  STEP(1)
  STEP(4)
  STEP(8)
  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }
  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
                        AVX_Data &data) {
    if (is_half) {
      data.data = SIMD_LOAD_HALF(h_ptr);
    } else {
      data.data = SIMD_LOAD(ptr);
    }
  }

  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
                         AVX_Data &data) {
    if (is_half) {
      SIMD_STORE_HALF(h_ptr, data.data);
    } else {
      SIMD_STORE(ptr, data.data);
    }
  }
#endif

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);

 private:
  float _alpha;
  float _betta1;
  float _betta2;
  float _eps;
  float _weight_decay;

  float _betta1_t;
  float _betta2_t;
  size_t _step;

  float _bias_correction1;
  float _bias_correction2;

  bool _adamw_mode;
};
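The SIMD_* macros above hide the AVX-512/AVX2 intrinsic names behind one vector-width-agnostic interface, and ROUND_DOWN/TILE split each buffer into an aligned SIMD region plus a scalar tail. A minimal standalone sketch of that pattern (a hypothetical illustration only, assuming an AVX2 + FMA build; it is not part of this commit):

#include <immintrin.h>
#include <cstddef>

// y[i] = a * x[i] + y[i]: vectorized over the ROUND_DOWN portion,
// scalar loop for the remaining tail elements.
void axpy_avx2(const float *x, float *y, float a, size_t n) {
  const size_t width = 8;             // floats per __m256 (SIMD_WIDTH for AVX2)
  size_t rounded = n & ~(width - 1);  // ROUND_DOWN(n, width)
  __m256 a_v = _mm256_set1_ps(a);     // SIMD_SET
  for (size_t i = 0; i < rounded; i += width) {
    __m256 x_v = _mm256_loadu_ps(x + i);   // SIMD_LOAD
    __m256 y_v = _mm256_loadu_ps(y + i);
    y_v = _mm256_fmadd_ps(a_v, x_v, y_v);  // SIMD_FMA
    _mm256_storeu_ps(y + i, y_v);          // SIMD_STORE
  }
  for (size_t i = rounded; i < n; ++i) y[i] = a * x[i] + y[i];  // scalar tail
}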
extensions/csrc/cuda/include/block_reduce.h (new file, 312 lines added)
@@ -0,0 +1,312 @@
/* Copyright 2021 The LightSeq Team
   Copyright Tencent/TurboTransformers
   This block_reduce_n is adapted from Tencent/TurboTransformers
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

enum class ReduceType { kMax = 0, kSum };
const unsigned int WARP_REDUCE_MASK = 0xffffffff;
const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32;

template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
    val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
  return val;
}

/* Calculate the sum of all elements in a block */
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0) shared[wid] = val;
  __syncthreads();

  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
  val = warpReduceSum<T>(val);
  return val;
}

template <ReduceType Rtype, int Num>
__inline__ __device__ void blockReduce(float *pval);

// use template to make code more concise
template <ReduceType Rtype, int Num>
__inline__ __device__ void warpReduce(float *pval);

// static
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32));
}

template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceMaxOneStep(a, b)                                 \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b);     \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  *(pval) = max(val0_tmp, *(pval));                                \
  *(pval + 1) = max(val1_tmp, *(pval + 1));

  WarpReduceMaxOneStep(16, 32);
  WarpReduceMaxOneStep(8, 32);
  WarpReduceMaxOneStep(4, 32);
  WarpReduceMaxOneStep(2, 32);
  WarpReduceMaxOneStep(1, 32);
#undef WarpReduceMaxOneStep
}

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32);
}

/*
 * Unorll for loop for warpreduce to
 * imporve instruction issue efficiency
 * ElemX means there are X numbers to be summed
 */

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b)                                     \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);     \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);     \
  *(pval + 0) += val0_tmp;                                             \
  *(pval + 1) += val1_tmp

  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);

#undef WarpReduceSumOneStep
}

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
  float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceSumOneStep(a, b)                                     \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);     \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);     \
  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b);     \
  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b);     \
  *(pval + 0) += val0_tmp;                                             \
  *(pval + 1) += val1_tmp;                                             \
  *(pval + 2) += val2_tmp;                                             \
  *(pval + 3) += val3_tmp

  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 1>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kSum, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = 0.f;
    }
  }
  warpReduce<ReduceType::kSum, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 2>(float *pval) {
  const int num = 2;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kSum, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = 0.f;
    }
  }
  warpReduce<ReduceType::kSum, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 4>(float *pval) {
  const int num = 4;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kSum, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = 0.f;
    }
  }
  warpReduce<ReduceType::kSum, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 1>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kMax, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = REDUCE_FLOAT_INF_NEG;
    }
  }
  warpReduce<ReduceType::kMax, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 2>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kMax, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = REDUCE_FLOAT_INF_NEG;
    }
  }
  warpReduce<ReduceType::kMax, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kMax, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = REDUCE_FLOAT_INF_NEG;
    }
  }
  warpReduce<ReduceType::kMax, num>(pval);
}
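A minimal usage sketch for the block-level reductions above (hypothetical, not part of this commit): one thread block per row, each thread accumulates a strided partial sum, and blockReduceSum combines the per-thread partials. It assumes blockDim.x is a multiple of the warp size (32).

#include "block_reduce.h"

// out[row] = sum of in[row * cols .. row * cols + cols - 1]
__global__ void row_sum_kernel(const float *in, float *out, int cols) {
  const float *row = in + blockIdx.x * cols;
  float partial = 0.f;
  for (int c = threadIdx.x; c < cols; c += blockDim.x) partial += row[c];
  float total = blockReduceSum<float>(partial);  // block-wide sum
  if (threadIdx.x == 0) out[blockIdx.x] = total;
}

// launch example: row_sum_kernel<<<rows, 256>>>(d_in, d_out, cols);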
extensions/csrc/cuda/layer_norm_cuda.cpp (new file, 141 lines added)
@@ -0,0 +1,141 @@
/*This code from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#include <torch/extension.h>

#include <cassert>
#include <vector>

#include "compat.h"

namespace {

void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                   int &n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
                at::Tensor beta) {
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
                int &n2) {
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim)
           .equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
                at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}
}  // namespace

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *input, int n1, int n2,
                     at::IntArrayRef normalized_shape, at::Tensor *gamma,
                     at::Tensor *beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                          at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta,
                                          double epsilon) {
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output =
      at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean =
      at::empty({n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
                  &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                              at::Tensor *invvar, at::Tensor *input, int n1,
                              int n2, at::IntArrayRef normalized_shape,
                              at::Tensor *gamma, at::Tensor *beta,
                              double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
    at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
        "LayerNorm backward (CUDA)");
}
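For reference, layer_norm_affine returns {output, mean, invvar}, where for each of the n1 rows the kernel in layer_norm_cuda_kernel.cu (next file) computes

$\mu = \frac{1}{n_2}\sum_i x_i, \quad \sigma^2 = \frac{1}{n_2}\sum_i (x_i-\mu)^2, \quad \mathrm{invvar} = \frac{1}{\sqrt{\sigma^2+\epsilon}}, \quad y_i = \gamma_i \cdot \mathrm{invvar}\,(x_i-\mu) + \beta_i,$

and the cached mean/invvar are passed back into layer_norm_gradient_affine so the backward kernels can rebuild the normalized activations without repeating the reduction.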
extensions/csrc/cuda/layer_norm_cuda_kernel.cu (new file, 683 lines added)
@@ -0,0 +1,683 @@
/*This code from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#include <cuda.h>
#include <cuda_runtime.h>

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include "type_shim.h"

template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}

template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
                                U& mu, U& sigma2, U& count) {
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA * mu + nB * muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}

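// For clarity: cuWelfordOnlineSum is Welford's online update
//   count += 1;  delta = x - mu;  mu += delta / count;  M2 += delta * (x - mu),
// with sigma2 holding the running sum of squared deviations (M2), and
// cuChanOnlineSum is Chan et al.'s merge of two partial (mu, M2, count) results:
//   delta = mu_B - mu_A,  n = n_A + n_B,
//   mu = (n_A * mu_A + n_B * mu_B) / n,
//   M2 = M2_A + M2_B + delta * delta * n_A * n_B / n.
// The caller below divides the merged M2 by n2, which yields the row variance.
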
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1,
                                  const int n2, const int i1, U& mu, U& sigma2,
                                  U* buf) {
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1 * n2;
    int l = 4 * thrx;
    for (; l + 3 < n2; l += 4 * numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l + k]);
        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset &&
            threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2 * threadIdx.y];
          U sigma2B = ubuf[2 * threadIdx.y + 1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
    }
  }
}

template <>
__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals,
                                  const int n1, const int n2, const int i1,
                                  float& mu, float& sigma2, float* buf) {
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1 * n2;
    int l = 8 * thrx;
    if ((((size_t)lvals) & 3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr, mu, sigma2, count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l + 7 < n2; l += 8 * numx) {
      for (int k = 0; k < 8; k += 2) {
        float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
        cuWelfordOnlineSum(curr.x, mu, sigma2, count);
        cuWelfordOnlineSum(curr.y, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset &&
            threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2 * threadIdx.y];
          float sigma2B = ubuf[2 * threadIdx.y + 1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
    }
  }
}

template <typename U>
U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template <>
float rsqrt(float v) {
  return rsqrtf(v);
}
template <>
double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
//   // Ensure that we won't compile any un-specialized types
//   __device__ T *getPointer()
//   {
//     extern __device__ void error(void);
//     error();
//     return NULL;
//   }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory<float> {
  __device__ float* getPointer() {
    extern __shared__ float s_float[];
    return s_float;
  }
};

}  // namespace

template <typename T, typename U, typename V>
__global__ void cuApplyLayerNorm(V* __restrict__ output_vals,
                                 U* __restrict__ mean, U* __restrict__ invvar,
                                 const T* __restrict__ vals, const int n1,
                                 const int n2, const U epsilon,
                                 const V* __restrict__ gamma,
                                 const V* __restrict__ beta) {
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu, sigma2;
    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
    const T* lvals = vals + i1 * n2;
    V* ovals = output_vals + i1 * n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
  }
}

template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
    const T* input, const V* dout, const int i1_end, const int n2,
    const U* __restrict__ mean, const U* __restrict__ invvar) {
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] =
            curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
    const T* input, const V* dout, const int i1_end, const int n2,
    const U* __restrict__ mean, const U* __restrict__ invvar) {
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] +=
            curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(
    const V* __restrict__ dout, const T* __restrict__ input, const int n1,
    const int n2, const U* __restrict__ mean, const U* __restrict__ invvar,
    U epsilon, U* part_grad_gamma, U* part_grad_beta) {
  const int numsegs_n1 =
      (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
  const int i1_beg_plus_one =
      (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x + 1;
  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int thr_load_row_off =
      (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U* buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y *
                                 // blockDim.y + (blockDim.y -
                                 // 1)*(blockDim.x/blockDim.y) elements
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
       i1_block += blockDim.y * blockDim.y) {
    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k * blockDim.y;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
                                       const U* part_grad_beta,
                                       const int part_size, const int n1,
                                       const int n2, V* grad_gamma,
                                       V* grad_beta) {
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
|
||||
}
|
||||
// inter-warp reductions
|
||||
const int nbsize3 = blockDim.x * blockDim.y / 2;
|
||||
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
|
||||
// top half write to shared memory
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
|
||||
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[write_idx] = sum_gamma;
|
||||
buf[write_idx + nbsize3] = sum_beta;
|
||||
}
|
||||
__syncthreads();
|
||||
// bottom half sums
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_gamma += buf[read_idx];
|
||||
sum_beta += buf[read_idx + nbsize3];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// write out fully summed gradients
|
||||
if (threadIdx.y == 0) {
|
||||
grad_gamma[i2] = sum_gamma;
|
||||
grad_beta[i2] = sum_beta;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
__global__ void cuComputeGradInput(const V* __restrict__ dout,
|
||||
const T* __restrict__ input, const int n1,
|
||||
const int n2, const U* __restrict__ mean,
|
||||
const U* __restrict__ invvar, U epsilon,
|
||||
const V* gamma, T* grad_input) {
|
||||
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
|
||||
U sum_loss1 = U(0);
|
||||
U sum_loss2 = U(0);
|
||||
const U c_mean = mean[i1];
|
||||
const U c_invvar = invvar[i1];
|
||||
const T* k_input = input + i1 * n2;
|
||||
const V* k_dout = dout + i1 * n2;
|
||||
const int numx = blockDim.x * blockDim.y;
|
||||
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
|
||||
if (gamma != NULL) {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
sum_loss1 += c_loss * gamma[l + k];
|
||||
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss * gamma[l];
|
||||
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
} else {
|
||||
int l = 4 * thrx;
|
||||
for (; l + 3 < n2; l += 4 * numx) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
const U c_h = static_cast<U>(k_input[l + k]);
|
||||
const U c_loss = static_cast<U>(k_dout[l + k]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
for (; l < n2; ++l) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
sum_loss1 += c_loss;
|
||||
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
|
||||
}
|
||||
}
|
||||
// intra-warp reductions
|
||||
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
|
||||
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
|
||||
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
|
||||
}
|
||||
// inter-warp reductions
|
||||
if (blockDim.y > 1) {
|
||||
SharedMemory<U> shared;
|
||||
U* buf = shared.getPointer();
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
|
||||
// upper half of warps write to shared
|
||||
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
|
||||
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
|
||||
buf[2 * wrt_i] = sum_loss1;
|
||||
buf[2 * wrt_i + 1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
// lower half merges
|
||||
if (threadIdx.y < offset) {
|
||||
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
|
||||
sum_loss1 += buf[2 * read_i];
|
||||
sum_loss2 += buf[2 * read_i + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.y == 0) {
|
||||
buf[2 * threadIdx.x] = sum_loss1;
|
||||
buf[2 * threadIdx.x + 1] = sum_loss2;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.y != 0) {
|
||||
sum_loss1 = buf[2 * threadIdx.x];
|
||||
sum_loss2 = buf[2 * threadIdx.x + 1];
|
||||
}
|
||||
}
|
||||
// all threads now have the two sums over l
|
||||
U fH = (U)n2;
|
||||
U term1 = (U(1) / fH) * c_invvar;
|
||||
T* k_grad_input = grad_input + i1 * n2;
|
||||
if (gamma != NULL) {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss * gamma[l];
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
} else {
|
||||
for (int l = thrx; l < n2; l += numx) {
|
||||
const U c_h = static_cast<U>(k_input[l]);
|
||||
const U c_loss = static_cast<U>(k_dout[l]);
|
||||
U f_grad_input = fH * c_loss;
|
||||
f_grad_input -= sum_loss1;
|
||||
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
|
||||
f_grad_input *= term1;
|
||||
k_grad_input[l] = static_cast<T>(f_grad_input);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1,
|
||||
int n2, double epsilon, const V* gamma, const V* beta) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const dim3 threads(32, 4, 1);
|
||||
const uint64_t maxGridY =
|
||||
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
int nshared =
|
||||
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
|
||||
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
|
||||
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
|
||||
}
|
||||
|
||||
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
|
||||
at::Tensor* input, int n1, int n2,
|
||||
#ifdef VERSION_GE_1_1
|
||||
at::IntArrayRef normalized_shape,
|
||||
#else
|
||||
at::IntList normalized_shape,
|
||||
#endif
|
||||
at::Tensor* gamma, at::Tensor* beta, double epsilon) {
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
|
||||
HostApplyLayerNorm(output->DATA_PTR<scalar_t_out>(),
|
||||
mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(),
|
||||
input->DATA_PTR<scalar_t_in>(), n1, n2, epsilon,
|
||||
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);)
|
||||
}
|
||||
|
||||
template <typename T, typename U, typename V>
|
||||
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar,
|
||||
at::Tensor* input, int n1, int n2, const V* gamma,
|
||||
const V* beta, double epsilon, T* grad_input,
|
||||
V* grad_gamma, V* grad_beta) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
if (gamma != NULL && beta != NULL) {
|
||||
// compute grad_gamma(j) and grad_beta(j)
|
||||
const int part_size = 16;
|
||||
const dim3 threads2(32, 4, 1);
|
||||
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
|
||||
const int nshared2_a =
|
||||
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
|
||||
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
|
||||
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
|
||||
at::Tensor part_grad_gamma = at::empty(
|
||||
{part_size, n2}, input->options().dtype(at::ScalarType::Float));
|
||||
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
|
||||
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
|
||||
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
|
||||
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
|
||||
|
||||
const dim3 threads3(32, 8, 1);
|
||||
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
|
||||
const int nshared3 = threads3.x * threads3.y * sizeof(U);
|
||||
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
|
||||
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size,
|
||||
n1, n2, grad_gamma, grad_beta);
|
||||
}
|
||||
|
||||
// compute grad_input
|
||||
const uint64_t maxGridY =
|
||||
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
|
||||
const dim3 threads1(32, 4, 1);
|
||||
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
|
||||
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
|
||||
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
|
||||
grad_input);
|
||||
}
|
||||
|
||||
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean,
|
||||
at::Tensor* invvar, at::Tensor* input, int n1,
|
||||
int n2,
|
||||
#ifdef VERSION_GE_1_1
|
||||
at::IntArrayRef normalized_shape,
|
||||
#else
|
||||
at::IntList normalized_shape,
|
||||
#endif
|
||||
at::Tensor* gamma, at::Tensor* beta,
|
||||
double epsilon, at::Tensor* grad_input,
|
||||
at::Tensor* grad_gamma, at::Tensor* grad_beta) {
|
||||
using namespace at;
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
|
||||
input->scalar_type(), gamma->scalar_type(),
|
||||
"cuda_layer_norm_gradient_kernel",
|
||||
HostLayerNormGradient(
|
||||
dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
|
||||
invvar->DATA_PTR<float>(), input, n1, n2,
|
||||
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
|
||||
// if gamma Tensor is NULL on input.
|
||||
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
|
||||
grad_input->DATA_PTR<scalar_t_in>(),
|
||||
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
|
||||
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
|
||||
}
|
97
extensions/csrc/cuda/moe_cuda.cpp
Normal file
97
extensions/csrc/cuda/moe_cuda.cpp
Normal file
@@ -0,0 +1,97 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx);
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
|
||||
|
||||
#define CHECK_CUDA(x) \
|
||||
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask, torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(batch_tokens);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(expert_grad);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
|
||||
}
|
||||
|
||||
torch::Tensor moe_combine_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(expert_tokens);
|
||||
CHECK_INPUT(logits);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
|
||||
dest_idx);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
|
||||
torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
CHECK_INPUT(tokens_grad);
|
||||
CHECK_INPUT(logits);
|
||||
CHECK_CUDA(mask);
|
||||
CHECK_CUDA(dest_idx);
|
||||
|
||||
return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
|
||||
logits, mask, dest_idx);
|
||||
}
|
||||
|
||||
torch::Tensor moe_cumsum(torch::Tensor mask) {
|
||||
CHECK_INPUT(mask);
|
||||
return cumsum_sub_one_in_dim0(mask);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");
|
||||
m.def("dispatch_forward", &moe_dispatch_forward,
|
||||
"Forward operation in MoE dispatch function");
|
||||
m.def("dispatch_backward", &moe_dispatch_backward,
|
||||
"Backward operation in MoE dispatch function");
|
||||
m.def("combine_forward", &moe_combine_forward,
|
||||
"Combine operation in MoE combine function");
|
||||
m.def("combine_backward", &moe_combine_backward,
|
||||
"Combine operation in MoE combine function");
|
||||
}
|
659
extensions/csrc/cuda/moe_cuda_kernel.cu
Normal file
659
extensions/csrc/cuda/moe_cuda_kernel.cu
Normal file
@@ -0,0 +1,659 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "block_reduce.h"
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, pack);
|
||||
BlockStore(ts_store).Store(src_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row1 + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row2 + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack1[pack_size], pack2[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
|
||||
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack1[i] += pack2[i];
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row + idx, pack1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack[i] *= weight;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
||||
T *weight_grad, const T weight, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens[pack_size];
|
||||
float thread_sum = 0;
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
BlockLoad(ts_load).Load(tks_row + idx, tokens);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
thread_sum += grad[i] * tokens[i];
|
||||
grad[i] *= weight;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row + idx, grad);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 1>(&thread_sum);
|
||||
|
||||
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
const T weight1, const T weight2,
|
||||
const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack1[pack_size], pack2[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
|
||||
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
|
||||
T *tks_row1, T *tks_row2, T *weight_grad1,
|
||||
T *weight_grad2, const T weight1,
|
||||
const T weight2, const int cols) {
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
|
||||
BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
|
||||
BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
|
||||
sgrad2[pack_size];
|
||||
float thread_sum[2] = {0, 0};
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
|
||||
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
thread_sum[0] += grad[i] * tokens1[i];
|
||||
thread_sum[1] += grad[i] * tokens2[i];
|
||||
sgrad1[i] = weight1 * grad[i];
|
||||
sgrad2[i] = weight2 * grad[i];
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
|
||||
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 2>(thread_sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
*weight_grad1 = static_cast<T>(thread_sum[0]);
|
||||
else if (threadIdx.x == 1)
|
||||
*weight_grad2 = static_cast<T>(thread_sum[1]);
|
||||
}
|
||||
|
||||
// DISPATCH KERNELS --------------------------------
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
||||
const int cols, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
||||
cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int h) {
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_fwd_selector<T, block_size, pack_size>(
|
||||
batch_tokens + (row * h), expert_input + (dest1[row] * h),
|
||||
expert_input + (dest2[row] * h), h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int h) {
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_bwd_selector<T, block_size, pack_size>(
|
||||
tokens_grad + (row * h), expert_grad + (dest1[row] * h),
|
||||
expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// COMBINE KERNELS --------------------------------
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
||||
const int cols, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
||||
const int cols, T *tks_row1, T *tks_row2,
|
||||
T *wt_grad1, T *wt_grad2, const T weight1,
|
||||
const T weight2, const int indicator1,
|
||||
const int indicator2) {
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
||||
tks_row1, tks_row2, wt_grad1,
|
||||
wt_grad2, weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
|
||||
wt_grad1, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
|
||||
wt_grad2, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
|
||||
T *logits, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int e, const int c,
|
||||
const int h) {
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e);
|
||||
moe_cb_fwd_selector<T, block_size, pack_size>(
|
||||
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
|
||||
combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
|
||||
indicator2);
|
||||
}
|
||||
|
||||
template <typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
|
||||
moe_cb_bwd_selector<T, block_size, pack_size>(
|
||||
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
|
||||
tokens_grad + (row * h), h, tks + (dest1[row] * h),
|
||||
tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
|
||||
row_log[eid2], mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// CUMSUM KERNEL --------------------------------
|
||||
|
||||
template <int block_size, int pack_size>
|
||||
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
||||
const int e) {
|
||||
assert(s % pack_size == 0);
|
||||
constexpr int bpack_size = block_size * pack_size;
|
||||
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
|
||||
__shared__ int temp[block_size + 1];
|
||||
int pack[pack_size];
|
||||
|
||||
for (int idx = 0; idx < s; idx += bpack_size) {
|
||||
int offset = 1;
|
||||
|
||||
if (idx + tps < s) {
|
||||
temp[tid] = inputs[tps * e + bid];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < pack_size; ++i) {
|
||||
pack[i] = inputs[(tps + i) * e + bid];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 1; i < pack_size; ++i) {
|
||||
temp[tid] += pack[i];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = block_size >> 1; i > 0; i >>= 1) {
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1;
|
||||
temp[j + offset] += temp[j];
|
||||
}
|
||||
offset <<= 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
temp[block_size] = temp[block_size - 1];
|
||||
temp[block_size - 1] = 0;
|
||||
}
|
||||
|
||||
for (int i = 1; i < block_size; i <<= 1) {
|
||||
offset >>= 1;
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
|
||||
temp[j] = temp[k];
|
||||
temp[k] += ts;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) temp[0] = temp[block_size];
|
||||
__syncthreads();
|
||||
|
||||
if (idx + tps < s) {
|
||||
temp[tid + 1] += last_sum;
|
||||
#pragma unroll
|
||||
for (int i = pack_size - 1; i > 0; --i) {
|
||||
outputs[(tps + i) * e + bid] = temp[tid + 1];
|
||||
temp[tid + 1] -= pack[i];
|
||||
}
|
||||
outputs[tps * e + bid] = temp[tid + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
last_sum += temp[0];
|
||||
inputs += bpack_size * e;
|
||||
outputs += bpack_size * e;
|
||||
}
|
||||
}
|
||||
|
||||
// LAUNCH FUNCTIONS --------------------------------
|
||||
|
||||
template <typename T>
|
||||
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
||||
int *mask2, int *dest1, int *dest2, const int s,
|
||||
const int h) {
|
||||
if (h < 256)
|
||||
moe_dpch_fwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_fwd_kernel<T, 32, 8>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_fwd_kernel<T, 32, 16>
|
||||
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_fwd_kernel<T, 64, 16>
|
||||
<<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_fwd_kernel<T, 128, 16>
|
||||
<<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int s, const int h) {
|
||||
if (h < 256)
|
||||
moe_dpch_bwd_kernel<T, 32, 4>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_bwd_kernel<T, 32, 8>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_bwd_kernel<T, 32, 16>
|
||||
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_bwd_kernel<T, 64, 16>
|
||||
<<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_bwd_kernel<T, 128, 16>
|
||||
<<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2, int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
if (h < 256)
|
||||
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 512)
|
||||
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 1024)
|
||||
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else if (h < 2048)
|
||||
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1, dest2,
|
||||
e, c, h);
|
||||
else
|
||||
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
|
||||
logits, mask1, mask2, dest1,
|
||||
dest2, e, c, h);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
||||
T *logits_grad, int *mask1, int *mask2, int *dest1,
|
||||
int *dest2, const int s, const int e, const int c,
|
||||
const int h) {
|
||||
if (h < 256)
|
||||
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
else // if (h < 512)
|
||||
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
|
||||
logits, logits_grad, mask1, mask2,
|
||||
dest1, dest2, e, c, h);
|
||||
// else if (h < 1024)
|
||||
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
|
||||
// dest1, dest2, e, c, h);
|
||||
// else
|
||||
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
|
||||
// dest1, dest2, e, c, h);
|
||||
}
|
||||
|
||||
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
||||
if (s <= 256)
|
||||
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
|
||||
else if (s <= 512)
|
||||
cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
|
||||
else if (s <= 1024)
|
||||
cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
|
||||
else if (s <= 2048)
|
||||
cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
|
||||
else
|
||||
cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
|
||||
}
|
||||
|
||||
// API FUNCTIONS --------------------------------
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using scalar_t = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros(
|
||||
{ec, h},
|
||||
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
batch_tokens.scalar_type(), "moe dispatch forward",
|
||||
moe_dpch_fwd_launch<scalar_t>(
|
||||
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros(
|
||||
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_grad.scalar_type(), "moe dispatch backward",
|
||||
moe_dpch_bwd_launch<scalar_t>(
|
||||
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto res = torch::zeros(
|
||||
{s, h},
|
||||
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_tokens.scalar_type(), "moe combine forward",
|
||||
moe_cb_fwd_launch<scalar_t>(
|
||||
expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
logits.data<scalar_t>(), mask[0].data<int>(),
|
||||
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
|
||||
h));
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
assert(h % 16 == 0);
|
||||
assert(tokens_grad.dtype() == expert_tokens.dtype());
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto egrad = torch::zeros(
|
||||
{e * c, h},
|
||||
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
|
||||
wgrad = torch::zeros(
|
||||
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tokens_grad.scalar_type(), "moe combine backward",
|
||||
moe_cb_bwd_launch<scalar_t>(
|
||||
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
|
||||
expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
|
||||
wgrad.data<scalar_t>(), mask[0].data<int>(),
|
||||
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
|
||||
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
|
||||
h));
|
||||
|
||||
return {egrad, wgrad};
|
||||
}
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
|
||||
assert(mask.dim() == 2);
|
||||
assert(mask.dtype() == torch::kInt32);
|
||||
|
||||
const int s = mask.size(0), e = mask.size(1);
|
||||
auto res =
|
||||
torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
|
||||
cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
|
||||
|
||||
return res;
|
||||
}
|
146
extensions/csrc/cuda/multi_tensor_adam.cu
Normal file
146
extensions/csrc/cuda/multi_tensor_adam.cu
Normal file
@@ -0,0 +1,146 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
|
||||
/* Copyright 2020 The Microsoft DeepSpeed Team
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
typedef enum {
|
||||
ADAM_MODE_0 = 0, // L2 regularization mode
|
||||
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
|
||||
} adamMode_t;
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T_g, typename T_p>
|
||||
struct AdamFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta1_correction,
|
||||
const float beta2_correction, const float epsilon, const float lr,
|
||||
adamMode_t mode, const float decay, const float div_scale) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
|
||||
// potentially use to pass in list of scalar
|
||||
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
r_g[ii] = g[i];
|
||||
r_p[ii] = p[i];
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
} else {
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (div_scale > 0) r_g[ii] /= div_scale;
|
||||
|
||||
if (mode == ADAM_MODE_0) { // L2
|
||||
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
MATH_T update = next_m_unbiased / denom;
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
} else { // weight decay
|
||||
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
|
||||
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
r_p[ii] = r_p[ii] - (lr * update);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
p[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr, const float beta1,
|
||||
const float beta2, const float epsilon,
|
||||
const int step, const int mode,
|
||||
const int bias_correction, const float weight_decay,
|
||||
const float div_scale) {
|
||||
using namespace at;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1) {
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
|
||||
tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
|
||||
"adam",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
|
||||
beta2, bias_correction1, bias_correction2, epsilon,
|
||||
lr, (adamMode_t)mode, weight_decay, div_scale);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
130
extensions/csrc/cuda/multi_tensor_apply.cuh
Normal file
130
extensions/csrc/cuda/multi_tensor_apply.cuh
Normal file
@@ -0,0 +1,130 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
|
||||
/* Copyright 2020 The Microsoft DeepSpeed Team
|
||||
Copyright NVIDIA/apex
|
||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||
Licensed under the MIT License.
|
||||
*/
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <assert.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "compat.h"
|
||||
|
||||
// #include <iostream>
|
||||
|
||||
// This header is the one-stop shop for all your multi-tensor apply needs.
|
||||
|
||||
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
|
||||
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
||||
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
||||
|
||||
template <int n>
|
||||
struct TensorListMetadata {
|
||||
void *addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int sizes[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
|
||||
// full int.
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
__global__ void multi_tensor_apply_kernel(int chunk_size,
|
||||
volatile int *noop_flag, T tl,
|
||||
U callable, ArgTypes... args) {
|
||||
// Hand the chunk information to the user-supplied functor to process however
|
||||
// it likes.
|
||||
callable(chunk_size, noop_flag, tl, args...);
|
||||
}
|
||||
|
||||
template <int depth, typename T, typename... ArgTypes>
|
||||
void multi_tensor_apply(
|
||||
int block_size, int chunk_size, const at::Tensor &noop_flag,
|
||||
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
|
||||
ArgTypes... args) {
|
||||
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
|
||||
int len0 = tensor_lists[0].size();
|
||||
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
|
||||
auto ref_device = tensor_lists[0][0].device();
|
||||
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
|
||||
for (int l = 0; l < tensor_lists.size();
|
||||
l++) // No range-based for because I need indices
|
||||
{
|
||||
TORCH_CHECK(tensor_lists[l].size() == len0,
|
||||
"Size mismatch among tensor lists");
|
||||
for (int t = 0; t < tensor_lists[l].size(); t++) {
|
||||
// TODO: Print which tensor fails.
|
||||
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
|
||||
#ifdef VERSION_GE_1_5
|
||||
contiguous_memory =
|
||||
(contiguous_memory ||
|
||||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
|
||||
#endif
|
||||
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
|
||||
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
|
||||
"A tensor was not on the same device as the first tensor");
|
||||
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
|
||||
"Size mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
|
||||
TensorListMetadata<depth> tl;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
tl.start_tensor_this_launch = 0;
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
|
||||
loc_tensor_info++;
|
||||
|
||||
int chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
|
||||
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
|
||||
// std::cout << chunks_this_tensor << std::endl;
|
||||
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
|
||||
tl.block_to_chunk[loc_block_info] = chunk;
|
||||
loc_block_info++;
|
||||
|
||||
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks_this_tensor - 1);
|
||||
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
|
||||
if (tensors_full || blocks_full || last_chunk) {
|
||||
// using accscalar_t = acc_type<scalar_t, true>;
|
||||
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
|
||||
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// Reset. The control flow possibilities here make my brain hurt.
|
||||
loc_block_info = 0;
|
||||
if (chunk == chunks_this_tensor - 1) {
|
||||
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
|
||||
// << std::endl;
|
||||
loc_tensor_info = 0;
|
||||
tl.start_tensor_this_launch = t + 1;
|
||||
} else {
|
||||
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
|
||||
// << std::endl;
|
||||
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
|
||||
for (int d = 0; d < depth; d++)
|
||||
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
|
||||
loc_tensor_info = 1;
|
||||
tl.start_tensor_this_launch = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
382
extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
Normal file
382
extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu
Normal file
@@ -0,0 +1,382 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||
int src_offset) {
|
||||
typedef
|
||||
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename x_t>
|
||||
struct L2NormFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
|
||||
x += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be
|
||||
// sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
vals[i] = 0.f;
|
||||
r_x[i] = 0;
|
||||
}
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_x, x, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
float next = static_cast<float>(r_x[ii]);
|
||||
vals[ii] += next * next;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
float next = static_cast<float>(x[i]);
|
||||
vals[ii] += next * next;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++) val += vals[i];
|
||||
|
||||
float final = reduce_block_into_lanes(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (!isfinite(final))
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] += final;
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
|
||||
max_chunks_per_tensor +
|
||||
chunk_idx] = final;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Probably better to template, but we are not likely to support other norms
|
||||
template <typename x_t>
|
||||
struct MaxNormFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
|
||||
float *output, float *output_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
|
||||
x += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
__shared__ float s_vals[512];
|
||||
|
||||
float vals[ILP]; // = {0}; // this probably works too but I want to be
|
||||
// sure...
|
||||
x_t r_x[ILP];
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
vals[i] = 0.f;
|
||||
r_x[i] = 0;
|
||||
}
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_x, x, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
float next = static_cast<float>(r_x[ii]);
|
||||
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
float next = static_cast<float>(x[i]);
|
||||
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float val = 0.f;
|
||||
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(s_vals, val);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
if (!isfinite(final))
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
|
||||
if (per_tensor)
|
||||
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
|
||||
max_chunks_per_tensor +
|
||||
chunk_idx] = final;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
__global__ void cleanup(float *output, float *output_per_tensor, float *ret,
|
||||
float *ret_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor) {
|
||||
__shared__ float vals[512];
|
||||
|
||||
if (blockIdx.x == 0) {
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320) val = output[threadIdx.x];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0) *ret = sqrt(final);
|
||||
}
|
||||
|
||||
if (per_tensor) {
|
||||
float *output_this_tensor =
|
||||
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val += output_this_tensor[i];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
|
||||
float *ret_per_tensor, bool per_tensor,
|
||||
int max_chunks_per_tensor, int norm_type,
|
||||
float alpha, float beta) {
|
||||
__shared__ float vals[512];
|
||||
|
||||
if (blockIdx.x == 0) {
|
||||
float val = 0;
|
||||
if (threadIdx.x < 320) val = output[threadIdx.x];
|
||||
|
||||
if (norm_type == 0) {
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
|
||||
} else {
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
|
||||
}
|
||||
}
|
||||
|
||||
if (per_tensor) {
|
||||
float *output_this_tensor =
|
||||
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||
|
||||
if (norm_type == 0) {
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
|
||||
|
||||
float final = reduce_block_into_lanes_max_op(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] =
|
||||
alpha * ret_per_tensor[blockIdx.x] + beta * final;
|
||||
} else {
|
||||
float val = 0;
|
||||
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
|
||||
val += output_this_tensor[i];
|
||||
|
||||
float final = reduce_block_into_lanes(vals, val);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] *
|
||||
ret_per_tensor[blockIdx.x] +
|
||||
beta * final);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python) {
|
||||
bool per_tensor =
|
||||
per_tensor_python.has_value() ? per_tensor_python.value() : false;
|
||||
|
||||
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
|
||||
auto output = at::zeros({320}, float_options);
|
||||
|
||||
at::Tensor output_per_tensor;
|
||||
at::Tensor ret_per_tensor;
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
int max_chunks_per_tensor = -1;
|
||||
|
||||
if (per_tensor) {
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
int max_chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
if (max_chunks_this_tensor > max_chunks_per_tensor)
|
||||
max_chunks_per_tensor = max_chunks_this_tensor;
|
||||
}
|
||||
output_per_tensor =
|
||||
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
|
||||
ret_per_tensor = at::empty({ntensors}, float_options);
|
||||
} else {
|
||||
ret_per_tensor = at::empty({0}, float_options);
|
||||
}
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
|
||||
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
per_tensor, max_chunks_per_tensor);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// This involves one more small kernel launch, but it will be negligible end
// to end. I could get rid of it by hacking the functor + multi tensor harness
// with persistence logic, but keeping it simple for now
|
||||
auto ret = at::empty({1}, output.options());
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
|
||||
output.DATA_PTR<float>(),
|
||||
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
|
||||
ret.DATA_PTR<float>(),
|
||||
per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr, per_tensor,
|
||||
max_chunks_per_tensor);
|
||||
|
||||
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
|
||||
}
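For orientation, a hypothetical host-side call (not part of this diff) that computes the global and per-tensor L2 norms; `grads` is an assumed std::vector<at::Tensor> of CUDA tensors and 65536 is just a typical chunk size:

// Hypothetical usage sketch of multi_tensor_l2norm_cuda.
auto noop = at::zeros({1}, at::dtype(at::kInt).device(at::kCUDA));
auto norms = multi_tensor_l2norm_cuda(/*chunk_size=*/65536, noop, {grads},
                                      /*per_tensor_python=*/true);
at::Tensor global_norm = std::get<0>(norms);      // scalar L2 norm of everything
at::Tensor per_tensor_norm = std::get<1>(norms);  // one norm per input tensor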
|
||||
|
||||
// Compute and update grad norm
|
||||
// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by
|
||||
// L-2: gn = sqrt(a * gn^2 + b * n^2)
|
||||
// L-inf: gn = a * gn + b * n
|
||||
void multi_tensor_norm_out_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor out,
|
||||
const float alpha, const float beta, const int norm_type) {
|
||||
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
|
||||
TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(),
|
||||
"noop flag should be on the same device as tensors");
|
||||
// we don't need the global norm here, thus we use empty
|
||||
auto output = at::empty({320}, float_options);
|
||||
|
||||
at::Tensor output_per_tensor;
|
||||
at::Tensor ret_per_tensor;
|
||||
|
||||
int ntensors = tensor_lists[0].size();
|
||||
int max_chunks_per_tensor = -1;
|
||||
|
||||
for (int t = 0; t < ntensors; t++) {
|
||||
int max_chunks_this_tensor =
|
||||
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
|
||||
if (max_chunks_this_tensor > max_chunks_per_tensor)
|
||||
max_chunks_per_tensor = max_chunks_this_tensor;
|
||||
}
|
||||
|
||||
// Although it is a single write then read, it still needs to be zeroed,
// since the trailing elements also participate in the cleanup
|
||||
output_per_tensor =
|
||||
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
|
||||
|
||||
if (norm_type == 0) {
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
MaxNormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
|
||||
} else {
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
|
||||
multi_tensor_apply<1>(
|
||||
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
|
||||
output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
|
||||
}
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
|
||||
// This involves one more small kernel launch, but it will be negligible end
// to end. I could get rid of it by hacking the functor + multi tensor harness
// with persistence logic, but keeping it simple for now
|
||||
auto ret = at::empty({1}, output.options());
|
||||
|
||||
// Adding the following device guard since it happens sometimes that the
|
||||
// tensors are on one device and the cuda stream is on another device which
|
||||
// results in ILLEGAL MEM ACCESS error.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
cleanup_v2<<<ntensors, 512, 0, stream>>>(
|
||||
output.DATA_PTR<float>(), output_per_tensor.DATA_PTR<float>(),
|
||||
ret.DATA_PTR<float>(), out.DATA_PTR<float>(), true, max_chunks_per_tensor,
|
||||
norm_type, alpha, beta);
|
||||
|
||||
return;
|
||||
}
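A minimal scalar sketch (not part of this diff) of the blending rule documented above multi_tensor_norm_out_cuda, assuming gn is the previously stored norm and n the freshly computed one:

#include <cmath>

// Blend an old norm gn with a new norm n using coefficients a and b.
float blend_norm(float gn, float n, float a, float b, int norm_type) {
  if (norm_type == 0) return a * gn + b * n;   // L-inf blend
  return std::sqrt(a * gn * gn + b * n * n);   // L-2 blend
}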
|
354
extensions/csrc/cuda/multi_tensor_lamb.cu
Normal file
@@ -0,0 +1,354 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||
int src_offset) {
|
||||
typedef
|
||||
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
typedef enum {
|
||||
MOMENT_MODE_0 = 0, // L2 regularization mode
|
||||
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
|
||||
} adamMode_t;
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
|
||||
int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
at::optional<bool> per_tensor_python);
|
||||
|
||||
using MATH_T = float;
|
||||
|
||||
template <typename T>
|
||||
struct LAMBStage1Functor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
|
||||
const float beta1, const float beta2, const float beta3,
|
||||
const float beta1_correction, const float beta2_correction,
|
||||
const float epsilon, adamMode_t mode, const float decay,
|
||||
const float *global_grad_norm, const float max_global_grad_norm) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
float clipped_global_grad_norm =
|
||||
(*global_grad_norm) > max_global_grad_norm
|
||||
? (*global_grad_norm) / max_global_grad_norm
|
||||
: 1.0f;
|
||||
|
||||
T *g = (T *)tl.addresses[0][tensor_loc];
|
||||
g += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
T *m = (T *)tl.addresses[2][tensor_loc];
|
||||
m += chunk_idx * chunk_size;
|
||||
|
||||
T *v = (T *)tl.addresses[3][tensor_loc];
|
||||
v += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) &&
|
||||
is_aligned(p) && is_aligned(m) && is_aligned(v)) {
|
||||
T l_g[ILP];
|
||||
T l_p[ILP];
|
||||
T l_m[ILP];
|
||||
T l_v[ILP];
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(l_g, g, 0, i_start);
|
||||
if (decay != 0) load_store(l_p, p, 0, i_start);
|
||||
load_store(l_m, m, 0, i_start);
|
||||
load_store(l_v, v, 0, i_start);
|
||||
// unpack
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_g[ii] = l_g[ii];
|
||||
if (decay == 0) {
|
||||
r_p[ii] = MATH_T(0);
|
||||
} else {
|
||||
r_p[ii] = l_p[ii];
|
||||
}
|
||||
r_m[ii] = l_m[ii];
|
||||
r_v[ii] = l_v[ii];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (mode == MOMENT_MODE_0) {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
// L2 on scaled grad
|
||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = next_m_unbiased / denom;
|
||||
} else {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
l_p[ii] = r_p[ii];
|
||||
l_m[ii] = r_m[ii];
|
||||
l_v[ii] = r_v[ii];
|
||||
}
|
||||
// store
|
||||
load_store(g, l_p, i_start, 0);
|
||||
load_store(m, l_m, i_start, 0);
|
||||
load_store(v, l_v, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
// see note in multi_tensor_scale_kernel.cu
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_g[ILP];
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_m[ILP];
|
||||
MATH_T r_v[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
r_g[ii] = g[i];
|
||||
// special ?optimization? for lamb stage 1
|
||||
if (decay == 0) {
|
||||
r_p[ii] = MATH_T(0);
|
||||
} else {
|
||||
r_p[ii] = p[i];
|
||||
}
|
||||
r_m[ii] = m[i];
|
||||
r_v[ii] = v[i];
|
||||
} else {
|
||||
r_g[ii] = MATH_T(0);
|
||||
r_p[ii] = MATH_T(0);
|
||||
r_m[ii] = MATH_T(0);
|
||||
r_v[ii] = MATH_T(0);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
if (mode == MOMENT_MODE_0) {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
// L2 on scaled grad
|
||||
scaled_grad = scaled_grad + decay * r_p[ii];
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = next_m_unbiased / denom;
|
||||
} else {
|
||||
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
|
||||
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
|
||||
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
|
||||
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
|
||||
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
|
||||
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
|
||||
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
g[i] = r_p[ii];
|
||||
m[i] = r_m[ii];
|
||||
v[i] = r_v[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
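As a reading aid, a scalar sketch (not part of this diff) of the Stage-1 update direction for MOMENT_MODE_0; bc1/bc2 stand for the bias corrections and clip for the clipped global grad norm:

#include <cmath>

// One element of the LAMB stage-1 update (L2 regularization mode).
float lamb_stage1_update(float g, float p, float &m, float &v, float beta1,
                         float beta2, float beta3, float bc1, float bc2,
                         float eps, float decay, float clip) {
  float sg = g / clip + decay * p;                // scaled grad plus L2 term
  m = m * beta1 + beta3 * sg;
  v = v * beta2 + (1 - beta2) * sg * sg;
  return (m / bc1) / (std::sqrt(v / bc2) + eps);  // written back into g
}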
|
||||
|
||||
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
|
||||
// It computes new parameter value.
|
||||
template <typename T>
|
||||
struct LAMBStage2Functor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
|
||||
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
|
||||
const float learning_rate, const float decay, bool use_nvlamb) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
MATH_T ratio = learning_rate;
|
||||
// nvlamb: apply adaptive learning rate to all parameters
|
||||
// otherwise, only apply to those with non-zero weight decay
|
||||
if (use_nvlamb || (decay != 0.0)) {
|
||||
float param_norm = per_tensor_param_norm[tensor_num];
|
||||
float update_norm = per_tensor_update_norm[tensor_num];
|
||||
ratio = (update_norm != 0.0f && param_norm != 0.0f)
|
||||
? learning_rate * (param_norm / update_norm)
|
||||
: learning_rate;
|
||||
}
|
||||
|
||||
T *update = (T *)tl.addresses[0][tensor_loc];
|
||||
update += chunk_idx * chunk_size;
|
||||
|
||||
T *p = (T *)tl.addresses[1][tensor_loc];
|
||||
p += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) &&
|
||||
is_aligned(update)) {
|
||||
T r_p[ILP];
|
||||
T r_update[ILP];
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_p, p, 0, i_start);
|
||||
load_store(r_update, update, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_p[ii] = static_cast<MATH_T>(r_p[ii]) -
|
||||
(ratio * static_cast<MATH_T>(r_update[ii]));
|
||||
}
|
||||
load_store(p, r_p, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
MATH_T r_p[ILP];
|
||||
MATH_T r_update[ILP];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
r_p[ii] = p[i];
|
||||
r_update[ii] = update[i];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
p[i] = r_p[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
const float lr, const float beta1,
|
||||
const float beta2, const float epsilon,
|
||||
const int step, const int bias_correction,
|
||||
const float weight_decay, const int grad_averaging,
|
||||
const int mode, at::Tensor global_grad_norm,
|
||||
const float max_grad_norm,
|
||||
at::optional<bool> use_nvlamb_python) {
|
||||
using namespace at;
|
||||
// Master weights and 32-bit momentum (potentially changing) are not handled
// by this, so we assume every tensor is of the same type
|
||||
|
||||
bool use_nvlamb =
|
||||
use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
|
||||
|
||||
// Handle bias correction mode
|
||||
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
|
||||
if (bias_correction == 1) {
|
||||
bias_correction1 = 1 - std::pow(beta1, step);
|
||||
bias_correction2 = 1 - std::pow(beta2, step);
|
||||
}
|
||||
|
||||
// Handle grad averaging mode
|
||||
float beta3 = 1.0f;
|
||||
if (grad_averaging == 1) beta3 = 1 - beta1;
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
|
||||
tensor_lists.begin() + 1);
|
||||
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1,
|
||||
tensor_lists.begin() + 2);
|
||||
|
||||
// Compute per tensor param norm
|
||||
auto param_norm_tuple =
|
||||
multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
|
||||
|
||||
// We now modify grad in place to store the update before computing its norm.
// Generally this is not an issue since people modify grad in the step() method
// all the time. We could also grab a list of empty tensors to avoid this, but
// I'd like to save space/cpu code
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
|
||||
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
|
||||
beta3, // 1-beta1 or 1 depends on averaging mode
|
||||
bias_correction1, bias_correction2, epsilon,
|
||||
(adamMode_t)mode, weight_decay,
|
||||
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
|
||||
|
||||
// Compute update norms
|
||||
auto update_norm_tuple =
|
||||
multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
|
||||
|
||||
std::vector<std::vector<at::Tensor>> grad_param_list(
|
||||
tensor_lists.begin(), tensor_lists.begin() + 2);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
|
||||
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,
|
||||
LAMBStage2Functor<scalar_t_0>(),
|
||||
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
|
||||
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
|
||||
lr, weight_decay, use_nvlamb);)
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
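A scalar sketch (not part of this diff) of the trust-ratio scaling that LAMBStage2Functor applies before the parameter update p -= ratio * update:

// nvlamb applies the adaptive rate to all parameters; otherwise only to
// those with non-zero weight decay.
float lamb_trust_ratio(float lr, float param_norm, float update_norm,
                       float decay, bool use_nvlamb) {
  if (use_nvlamb || decay != 0.0f)
    return (update_norm != 0.0f && param_norm != 0.0f)
               ? lr * (param_norm / update_norm)
               : lr;
  return lr;
}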
|
125
extensions/csrc/cuda/multi_tensor_scale_kernel.cu
Normal file
@@ -0,0 +1,125 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
// Another possibility:
|
||||
// #include <torch/all.h>
|
||||
|
||||
#include <assert.h>
|
||||
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
|
||||
#include <sstream>
|
||||
|
||||
#include "multi_tensor_apply.cuh"
|
||||
#include "type_shim.h"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T *p) {
|
||||
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
|
||||
int src_offset) {
|
||||
typedef
|
||||
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
|
||||
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
|
||||
}
|
||||
|
||||
template <typename in_t, typename out_t>
|
||||
struct ScaleFunctor {
|
||||
__device__ __forceinline__ void operator()(int chunk_size,
|
||||
volatile int *noop_gmem,
|
||||
TensorListMetadata<2> &tl,
|
||||
float scale) {
|
||||
// I'd like this kernel to propagate infs/nans.
|
||||
// if(*noop_gmem == 1)
|
||||
// return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
in_t *in = (in_t *)tl.addresses[0][tensor_loc];
|
||||
in += chunk_idx * chunk_size;
|
||||
|
||||
out_t *out = (out_t *)tl.addresses[1][tensor_loc];
|
||||
out += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
bool finite = true;
|
||||
in_t r_in[ILP];
|
||||
out_t r_out[ILP];
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) &&
|
||||
is_aligned(out)) {
|
||||
for (int i_start = threadIdx.x;
|
||||
i_start * ILP < n && i_start * ILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(r_in, in, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
|
||||
finite = finite && isfinite(r_in[ii]);
|
||||
}
|
||||
// store
|
||||
load_store(out, r_out, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
// Non-divergent exit condition for __syncthreads, not necessary here
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_in[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) r_in[ii] = in[i];
|
||||
}
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point
|
||||
// unrolling the write loop, since writes just fire off once their LDGs
|
||||
// arrive. Put another way, the STGs are dependent on the LDGs, but not
|
||||
// on each other. There is still compute ILP benefit from unrolling the
|
||||
// loop though.
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
|
||||
finite = finite && isfinite(r_in[ii]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) out[i] = r_out[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!finite)
|
||||
*noop_gmem =
|
||||
1; // Blindly fire off a write. These will race but that's ok.
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float scale) {
|
||||
using namespace at;
|
||||
// The output (downscaled) type is always float.
|
||||
// If build times suffer, think about where to put this dispatch,
|
||||
// and what logic should be moved out of multi_tensor_apply.
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
|
||||
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
ScaleFunctor<scalar_t_0, scalar_t_1>(),
|
||||
scale);))
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
// AT_CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
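A hypothetical host-side usage sketch (not part of this diff); fp16_grads, fp32_grads and loss_scale are assumed to exist on the caller's side:

// Unscale fp16 grads into fp32 buffers with overflow detection via noop_flag.
auto noop = at::zeros({1}, at::dtype(at::kInt).device(at::kCUDA));
std::vector<std::vector<at::Tensor>> lists = {fp16_grads, fp32_grads};
multi_tensor_scale_cuda(/*chunk_size=*/65536, noop, lists, 1.0f / loss_scale);
bool overflow = noop.item<int>() != 0;  // set to 1 if any input was inf/nan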
|
167
extensions/csrc/cuda/multi_tensor_sgd_kernel.cu
Normal file
@@ -0,0 +1,167 @@
|
||||
// modified from
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <assert.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "compat.h"
|
||||
#include "multi_tensor_apply.cuh"
|
||||
|
||||
#define BLOCK_SIZE 512
|
||||
#define ILP 4
|
||||
|
||||
/**
|
||||
* Perform fused SGD on multiple buffers
|
||||
* N: number of tensors
|
||||
* tl[0] : gradients
|
||||
* tl[1] : weights
|
||||
* tl[2] : momentum buffers
|
||||
* tl[3] : fp16 weights (if appropriate)
|
||||
* wd : weight_decay (scalar)
|
||||
* momentum : momentum (scalar)
|
||||
* dampening : momentum dampening (scalar)
|
||||
* lr : learning rate (scalar)
|
||||
* nesterov : enable nesterov (bool)
|
||||
* first run : necessary for proper momentum handling & init
|
||||
* wd_after_momentum : apply weight decay _after_ momentum instead of before
|
||||
**/
|
||||
template <typename T_grad, typename T_weight>
|
||||
struct SGDFunctor {
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
|
||||
float wd, float momentum, float dampening, float lr, bool nesterov,
|
||||
bool first_run, bool wd_after_momentum, float scale) {
|
||||
// Early exit if we don't need to do anything
|
||||
if (*noop_gmem) return;
|
||||
|
||||
int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
int chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
int n = tl.sizes[tensor_loc];
|
||||
|
||||
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
|
||||
grad_in += chunk_idx * chunk_size;
|
||||
|
||||
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
|
||||
weight_in += chunk_idx * chunk_size;
|
||||
|
||||
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
|
||||
mom_in += chunk_idx * chunk_size;
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
|
||||
// Non-divergent exit condition for the __syncthreads
|
||||
float incoming_grads[ILP];
|
||||
float incoming_weights[ILP];
|
||||
float incoming_moms[ILP];
|
||||
for (int i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * ILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
incoming_grads[ii] = 0;
|
||||
incoming_weights[ii] = 0;
|
||||
incoming_moms[ii] = 0;
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
|
||||
incoming_weights[ii] = static_cast<float>(weight_in[i]);
|
||||
incoming_moms[ii] = static_cast<float>(mom_in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// note for clarification to future michael:
|
||||
// From a pure memory dependency perspective, there's likely no point unrolling
|
||||
// the write loop, since writes just fire off once their LDGs arrive.
|
||||
// Put another way, the STGs are dependent on the LDGs, but not on each other.
|
||||
// There is still compute ILP benefit from unrolling the loop though.
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ILP; ii++) {
|
||||
int i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
if (i < n && i < chunk_size) {
|
||||
// apply weight decay before momentum if necessary
|
||||
if (wd != 0.f && !wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
if (momentum != 0.f) {
|
||||
if (!first_run)
|
||||
incoming_moms[ii] = incoming_moms[ii] * momentum +
|
||||
(1.f - dampening) * incoming_grads[ii];
|
||||
else // initialize momentums to current incoming grads
|
||||
incoming_moms[ii] = incoming_grads[ii];
|
||||
|
||||
if (nesterov)
|
||||
incoming_grads[ii] += momentum * incoming_moms[ii];
|
||||
else
|
||||
incoming_grads[ii] = incoming_moms[ii];
|
||||
}
|
||||
|
||||
// Apply WD after momentum if desired
|
||||
if (wd != 0.f && wd_after_momentum)
|
||||
incoming_grads[ii] += wd * incoming_weights[ii];
|
||||
|
||||
// adjust the weight and write out
|
||||
weight_in[i] += (-lr * incoming_grads[ii]);
|
||||
|
||||
// also write out the new momentum
|
||||
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists,
|
||||
float wd, float momentum, float dampening, float lr,
|
||||
bool nesterov, bool first_run,
|
||||
bool wd_after_momentum, float scale) {
|
||||
auto num_tensors = tensor_lists.size();
|
||||
auto grad_type = tensor_lists[0][0].scalar_type();
|
||||
auto weight_type = tensor_lists[1][0].scalar_type();
|
||||
|
||||
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
|
||||
"expected noop flag to be on the same device as tensors");
|
||||
|
||||
// We have 3 possibilities to handle here, in terms of
|
||||
// grad_type, param_type, momentum_type
|
||||
// 1. fp16, fp16, fp16
|
||||
// 2. fp32, fp32, fp32
|
||||
// 3. fp16, fp32, fp32
|
||||
// It's easier to hardcode these possibilities than to use
|
||||
// switches etc. to handle the cross-product of cases where
|
||||
// we don't want the majority of them.
|
||||
|
||||
// Case 1. fp16, fp16, fp16, No
|
||||
if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Half && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<at::Half, at::Half>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
}
|
||||
// Case 2. fp32, fp32, fp32
|
||||
else if (grad_type == at::ScalarType::Float &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<float, float>(), wd, momentum, dampening,
|
||||
lr, nesterov, first_run, wd_after_momentum, scale);
|
||||
}
|
||||
// Case 3. fp16, fp32, fp32
|
||||
else if (grad_type == at::ScalarType::Half &&
|
||||
weight_type == at::ScalarType::Float && num_tensors == 3) {
|
||||
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
|
||||
SGDFunctor<at::Half, float>(), wd, momentum,
|
||||
dampening, lr, nesterov, first_run, wd_after_momentum,
|
||||
scale);
|
||||
} else {
|
||||
AT_ERROR(
|
||||
"multi_tensor_sgd only supports some combinations of gradient & weight "
|
||||
"types. Given: ",
|
||||
"gradient: ", grad_type, ", weight: ", weight_type,
|
||||
", num_lists: ", num_tensors);
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
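For reference, a scalar sketch (not part of this diff) of the per-element update SGDFunctor performs; names mirror the kernel arguments:

// One element of the fused SGD step: weight decay, momentum, nesterov, update.
void sgd_step(float &w, float &mom, float g, float wd, float momentum,
              float dampening, float lr, bool nesterov, bool first_run,
              bool wd_after_momentum) {
  if (wd != 0.f && !wd_after_momentum) g += wd * w;  // decay before momentum
  if (momentum != 0.f) {
    mom = first_run ? g : mom * momentum + (1.f - dampening) * g;
    g = nesterov ? g + momentum * mom : mom;
  }
  if (wd != 0.f && wd_after_momentum) g += wd * w;   // decay after momentum
  w -= lr * g;
}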
|
70
extensions/csrc/cuda/scaled_masked_softmax.cpp
Normal file
@@ -0,0 +1,70 @@
|
||||
/* This code is from NVIDIA Megatron, with minor changes. */
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace multihead_attn {
|
||||
namespace fused_softmax {
|
||||
namespace scaled_masked_softmax {
|
||||
|
||||
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
|
||||
float scale_factor);
|
||||
|
||||
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results,
|
||||
float scale_factor);
|
||||
|
||||
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads);
|
||||
|
||||
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
|
||||
float scale_factor) {
|
||||
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
|
||||
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
|
||||
(input.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
|
||||
|
||||
return fwd_cuda(input, mask, scale_factor);
|
||||
}
|
||||
|
||||
torch::Tensor bwd(torch::Tensor const& output_grads,
|
||||
torch::Tensor const& softmax_results, float scale_factor) {
|
||||
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
|
||||
|
||||
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
|
||||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
|
||||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
|
||||
"Only fp16 and bf16 are supported");
|
||||
|
||||
return bwd_cuda(output_grads, softmax_results, scale_factor);
|
||||
}
|
||||
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads) {
|
||||
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
|
||||
attn_heads);
|
||||
}
|
||||
|
||||
} // end namespace scaled_masked_softmax
|
||||
} // end namespace fused_softmax
|
||||
} // end namespace multihead_attn
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Forward.");
|
||||
|
||||
m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
|
||||
"Self Multihead Attention scaled, time masked softmax -- Backward.");
|
||||
|
||||
m.def("get_batch_per_block",
|
||||
&multihead_attn::fused_softmax::scaled_masked_softmax::
|
||||
get_batch_per_block,
|
||||
"Return Batch per block size.");
|
||||
}
|
538
extensions/csrc/cuda/scaled_masked_softmax.h
Normal file
@@ -0,0 +1,538 @@
|
||||
/* This code is from NVIDIA Megatron, with minor changes. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <cfloat>
|
||||
#include <limits>
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
|
||||
const uint8_t *src) {
|
||||
*((half2 *)dst) = *((half2 *)src);
|
||||
}
|
||||
|
||||
int log2_ceil(int value) {
|
||||
int log2_value = 0;
|
||||
while ((1 << log2_value) < value) ++log2_value;
|
||||
return log2_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Add {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Max {
|
||||
__device__ __forceinline__ T operator()(T a, T b) const {
|
||||
return a < b ? b : a;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T
|
||||
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
|
||||
unsigned int mask = 0xffffffff) {
|
||||
#if CUDA_VERSION >= 9000
|
||||
return __shfl_xor_sync(mask, value, laneMask, width);
|
||||
#else
|
||||
return __shfl_xor(value, laneMask, width);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
|
||||
template <typename> class ReduceOp>
|
||||
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
|
||||
ReduceOp<acc_t> r;
|
||||
#pragma unroll
|
||||
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
||||
sum[i] = r(sum[i], b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Extended softmax (from native aten pytorch) with following additional
|
||||
* features 1) input scaling 2) Explicit masking
|
||||
*/
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_forward(
|
||||
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
|
||||
int micro_batch_size, int element_count, int pad_batches) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_forward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch =
|
||||
(blockDim.y *
|
||||
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
|
||||
threadIdx.y) *
|
||||
WARP_BATCH;
|
||||
int pad_first_batch = 0;
|
||||
if (pad_batches != 1) { // bert style
|
||||
pad_first_batch =
|
||||
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
|
||||
WARP_BATCH;
|
||||
} else { // gpt2 style
|
||||
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
}
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
|
||||
// load data from global memory
|
||||
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
|
||||
input_t temp_data[ELEMENTS_PER_LDG_STG];
|
||||
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
|
||||
if (element_index < batch_element_count) {
|
||||
int itr_idx = i * element_count + it * WARP_SIZE;
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
|
||||
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
if (temp_mask[element] != 1) {
|
||||
elements[i][it + element] = (acc_t)temp_data[element] * scale;
|
||||
} else {
|
||||
elements[i][it + element] = -10000.0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute max_value
|
||||
acc_t max_value[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
max_value[i] = elements[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
max_value[i] =
|
||||
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
|
||||
|
||||
acc_t sum[WARP_BATCH]{0.0f};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; ++it) {
|
||||
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
|
||||
sum[i] += elements[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] = elements[i][it + element] / sum[i];
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
dst + i * element_count + it * WARP_SIZE, out);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
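A single-row scalar sketch (not part of this diff) of the scaled, masked softmax the warp kernel computes; mask values equal to 1 mark suppressed positions:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Scale, mask, then softmax-normalize one row of n logits.
void scaled_masked_softmax_row(float *dst, const float *src,
                               const uint8_t *mask, float scale, int n) {
  std::vector<float> e(n);
  float mx = -std::numeric_limits<float>::infinity();
  for (int i = 0; i < n; ++i) {
    e[i] = (mask[i] != 1) ? src[i] * scale : -10000.0f;
    mx = std::max(mx, e[i]);
  }
  float sum = 0.f;
  for (int i = 0; i < n; ++i) { e[i] = std::exp(e[i] - mx); sum += e[i]; }
  for (int i = 0; i < n; ++i) dst[i] = e[i] / sum;
}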
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t,
|
||||
int log2_elements>
|
||||
__global__ void scaled_masked_softmax_warp_backward(
|
||||
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
|
||||
int micro_batch_size, int element_count) {
|
||||
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||
// warp_size of method warp_softmax_backward_kernel.
|
||||
constexpr int next_power_of_two = 1 << log2_elements;
|
||||
constexpr int WARP_SIZE =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
|
||||
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
|
||||
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
|
||||
|
||||
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
|
||||
// gridDim/blockIdx = (seq_len, attn_heads, batches)
|
||||
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
|
||||
|
||||
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to be computed within this WARP.
|
||||
int local_batches = micro_batch_size - first_batch;
|
||||
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
|
||||
|
||||
// there might be multiple batches per warp. compute the index within the
|
||||
// batch
|
||||
int local_idx = threadIdx.x;
|
||||
|
||||
// the first element to process by the current thread
|
||||
int thread_offset =
|
||||
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
|
||||
grad += thread_offset;
|
||||
output += thread_offset;
|
||||
gradInput += thread_offset;
|
||||
|
||||
// load data from global memory
|
||||
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
|
||||
input_t temp_grad[ELEMENTS_PER_LDG_STG];
|
||||
input_t temp_output[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
int batch_element_count = (i >= local_batches) ? 0 : element_count;
|
||||
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < batch_element_count) {
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_grad, grad + i * element_count + it * WARP_SIZE);
|
||||
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
|
||||
temp_output, output + i * element_count + it * WARP_SIZE);
|
||||
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
output_reg[i][it + element] = (acc_t)temp_output[element];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
grad_reg[i][it + element] =
|
||||
(acc_t)temp_grad[element] * output_reg[i][it + element];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
acc_t sum[WARP_BATCH];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
sum[i] = grad_reg[i][0];
|
||||
#pragma unroll
|
||||
for (int it = 1; it < WARP_ITERATIONS; ++it) {
|
||||
sum[i] += grad_reg[i][it];
|
||||
}
|
||||
}
|
||||
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
|
||||
|
||||
// store result
|
||||
#pragma unroll
|
||||
for (int i = 0; i < WARP_BATCH; ++i) {
|
||||
if (i >= local_batches) break;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
|
||||
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
|
||||
if (element_index < element_count) {
|
||||
// compute gradients
|
||||
output_t out[ELEMENTS_PER_LDG_STG];
|
||||
#pragma unroll
|
||||
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
|
||||
out[element] =
|
||||
(output_t)(scale * (grad_reg[i][it + element] -
|
||||
output_reg[i][it + element] * sum[i]));
|
||||
}
|
||||
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
|
||||
gradInput + i * element_count + it * WARP_SIZE, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
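Likewise, a scalar sketch (not part of this diff) of the backward pass per row: dx_i = scale * y_i * (dy_i - sum_j dy_j * y_j), where y is the saved softmax output:

// One row of the scaled softmax backward pass.
void scaled_softmax_backward_row(float *dx, const float *dy, const float *y,
                                 float scale, int n) {
  float dot = 0.f;
  for (int i = 0; i < n; ++i) dot += dy[i] * y[i];
  for (int i = 0; i < n; ++i) dx[i] = scale * y[i] * (dy[i] - dot);
}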
|
||||
} // end of anonymous namespace
|
||||
|
||||
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
|
||||
int attn_heads) {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
constexpr int threads_per_block = 128;
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
|
||||
return batches_per_block;
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
|
||||
const uint8_t *mask,
|
||||
const input_t scale,
|
||||
int query_seq_len, int key_seq_len,
|
||||
int batches, int attn_heads,
|
||||
int pad_batches) {
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_forward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
|
||||
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename input_t, typename output_t, typename acc_t>
|
||||
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
|
||||
input_t *grad,
|
||||
const input_t *output,
|
||||
const acc_t scale,
|
||||
int query_seq_len, int key_seq_len,
|
||||
int batches, int attn_heads) {
|
||||
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
|
||||
if (key_seq_len == 0) {
|
||||
return;
|
||||
} else {
|
||||
int log2_elements = log2_ceil(key_seq_len);
|
||||
const int next_power_of_two = 1 << log2_elements;
|
||||
int batch_count = batches * attn_heads * query_seq_len;
|
||||
|
||||
// This value must match the WARP_SIZE constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int warp_size =
|
||||
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
|
||||
|
||||
// This value must match the WARP_BATCH constexpr value computed inside
|
||||
// softmax_warp_backward.
|
||||
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
|
||||
|
||||
// use 128 threads per block to maximize gpu utilization
|
||||
constexpr int threads_per_block = 128;
|
||||
|
||||
int warps_per_block = (threads_per_block / warp_size);
|
||||
int batches_per_block = warps_per_block * batches_per_warp;
|
||||
int blocks = batch_count / batches_per_block;
|
||||
dim3 threads(warp_size, warps_per_block, 1);
|
||||
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
|
||||
switch (log2_elements) {
|
||||
case 0: // 1
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 1: // 2
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 2: // 4
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 3: // 8
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 4: // 16
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 5: // 32
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 6: // 64
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 7: // 128
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 8: // 256
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 9: // 512
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 10: // 1024
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
case 11: // 2048
|
||||
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
|
||||
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
grad_input, grad, output, scale, batch_count, key_seq_len);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
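As a quick sanity check on the arithmetic above, here is a minimal sketch (values assumed, not part of the diff) of the launch geometry the backward dispatcher derives for batches = 8, attn_heads = 16, query_seq_len = key_seq_len = 1024 and C10_WARP_SIZE = 32:

// Illustrative only: mirrors dispatch_scaled_masked_softmax_backward above.
inline int example_backward_launch_geometry() {
  const int log2_elements = 10;                                    // log2_ceil(1024)
  const int next_power_of_two = 1 << log2_elements;                // 1024
  const int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;   // 32
  const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;           // 1
  const int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;                 // 4
  const int batches_per_block = warps_per_block * batches_per_warp;          // 4
  const int batch_count = 8 * 16 * 1024;                                     // 131072
  return batch_count / batches_per_block;  // 32768 blocks of dim3(32, 4, 1) threads
}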
89
extensions/csrc/cuda/scaled_masked_softmax_cuda.cu
Normal file
@@ -0,0 +1,89 @@
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
                             int attn_heads) {
  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
                       float scale_factor) {
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
  // seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results = torch::empty(
      {batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
          query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
  // seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, query_seq_len, key_seq_len, batches, attn_heads););

  // backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
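A hypothetical host-side call of fwd_cuda (not part of the diff; it assumes the declarations above are visible to the caller). The shapes follow the TORCH_INTERNAL_ASSERT checks: a 4-D fp16/bf16 score tensor and a byte mask that either matches the batch dimension or broadcasts over it.

// Sketch only: build inputs that satisfy fwd_cuda's shape/dtype checks.
torch::Tensor example_masked_softmax_fwd() {
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto scores = torch::randn({8, 16, 128, 128}, opts);           // [b, h, q, k]
  auto mask = torch::zeros({1, 1, 128, 128},
                           opts.dtype(torch::kUInt8));           // pad_batches == 1
  return multihead_attn::fused_softmax::scaled_masked_softmax::fwd_cuda(
      scores, mask, /*scale_factor=*/0.125f);
}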
54
extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp
Normal file
@@ -0,0 +1,54 @@
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
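For illustration, a hypothetical direct C++ call of the binding target above (not part of the diff). fwd expects a 3-D fp16/bf16 tensor of shape [attn_batches, seq_len, seq_len], with seq_len at most 2048 and attn_batches divisible by the batches-per-block value the dispatcher computes.

// Sketch only: assumed shapes chosen to pass the asserts in fwd/fwd_cuda.
torch::Tensor example_upper_triang_fwd() {
  auto opts =
      torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
  auto scores = torch::randn({32, 256, 256}, opts);   // [attn_batches, q, k]
  return multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd(
      scores, /*scale_factor=*/0.125f);
}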
600
extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h
Normal file
@@ -0,0 +1,600 @@
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#pragma once

#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>

#include <cfloat>
#include <limits>

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
                                                     const c10::Half *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
                                                     const c10::Half *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
                                                   const uint8_t *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
                                                   const uint8_t *src) {
  *((half2 *)dst) = *((half2 *)src);
}

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(
    c10::BFloat16 *dst) {
  *dst = 0.0;
}

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(
    c10::BFloat16 *dst) {
  *((float2 *)dst) = make_float2(0.0f, 0.0f);
}

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) {
  *dst = 0.0;
}

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
  *((float2 *)dst) = make_float2(0.0f, 0.0f);
}

int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

template <typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
                     unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
          template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
  ReduceOp<acc_t> r;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
      sum[i] = r(sum[i], b);
    }
  }
}

/*
 * Extended softmax (from native aten pytorch) with following additional
 * features 1) input scaling 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t,
          int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
    int stride, int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_forward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  int first_batch =
      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
      blockIdx.x;
  int local_seq = blockIdx.x + 1;
  int warp_iteration_limit =
      (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

  // load data from global memory
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : local_seq;

#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
            temp_data, src + i * element_count * stride + it * WARP_SIZE);

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if ((element_index + element) < batch_element_count) {
            elements[i][it + element] = (acc_t)temp_data[element] * scale;
          } else {
            elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
          }
        }
      } else {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
        }
      }
    }
  }

  // compute max_value
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] =
          (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      if (it < warp_iteration_limit) {
        elements[i][it] = std::exp((elements[i][it] - max_value[i]));
        sum[i] += elements[i][it];
      }
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
  output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

      if (element_index < local_seq) {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < local_seq) {
            out[element] = elements[i][it + element] / sum[i];
          } else {
            out[element] = 0;
          }
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
            dst + i * element_count * stride + it * WARP_SIZE, out);
      } else if (element_index < element_count) {
        copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
            dst + i * element_count * stride + it * WARP_SIZE);
      } else {
        break;
      }
    }
  }
}

template <typename input_t, typename output_t, typename acc_t,
          int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
    int micro_batch_size, int stride, int element_count) {
  // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
  // warp_size of method warp_softmax_backward_kernel.
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  int first_batch =
      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
      blockIdx.x;
  int local_seq = blockIdx.x + 1;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread
  int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  input_t temp_grad[ELEMENTS_PER_LDG_STG];
  input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    int batch_element_count = (i >= local_batches) ? 0 : local_seq;

#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
            temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
            temp_output, output + i * element_count * stride + it * WARP_SIZE);

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < batch_element_count) {
            output_reg[i][it + element] = (acc_t)temp_output[element];
          }
        }
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (element_index + element < batch_element_count) {
            grad_reg[i][it + element] =
                (acc_t)temp_grad[element] * output_reg[i][it + element];
          }
        }
      }
    }
  }

  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] =
              (output_t)(scale * (grad_reg[i][it + element] -
                                  output_reg[i][it + element] * sum[i]));
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
            gradInput + i * element_count * stride + it * WARP_SIZE, out);
      }
    }
  }
}

}  // end of anonymous namespace

template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst, const input_t *src, const input_t scale,
    int softmax_elements, int softmax_elements_stride, int attn_batches) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside
    // softmax_warp_forward.
    int warp_size =
        (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 0>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 1:  // 2
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 1>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 2:  // 4
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 2>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 3:  // 8
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 3>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 4:  // 16
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 4>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 5:  // 32
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 5>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 6:  // 64
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 6>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 7:  // 128
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 7>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 8:  // 256
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 8>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 9:  // 512
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 9>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 10:  // 1024
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 10>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      case 11:  // 2048
        scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
                                                        acc_t, 11>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                dst, src, scale, batch_count, softmax_elements_stride,
                softmax_elements);
        break;
      default:
        break;
    }
  }
}

template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input, input_t *grad, const input_t *output,
    const acc_t scale, int softmax_elements, int softmax_elements_stride,
    int attn_batches) {
  TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
  if (softmax_elements == 0) {
    return;
  } else {
    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside
    // softmax_warp_backward.
    int warp_size =
        (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
      case 0:  // 1
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 0>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 1:  // 2
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 1>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 2:  // 4
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 2>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 3:  // 8
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 3>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 4:  // 16
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 4>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 5:  // 32
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 5>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 6:  // 64
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 6>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 7:  // 128
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 7>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 8:  // 256
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 8>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 9:  // 512
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 9>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 10:  // 1024
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 10>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      case 11:  // 2048
        scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
                                                         acc_t, 11>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                grad_input, grad, output, scale, batch_count,
                softmax_elements_stride, softmax_elements);
        break;
      default:
        break;
    }
  }
}
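For reference, a short worked example (assumed values, not part of the diff) of how the dispatchers above pick a kernel instantiation:

// Illustrative only: assumed softmax_elements = 300 and C10_WARP_SIZE = 32.
// log2_ceil(300) = 9, so 300 is rounded up to a 512-wide row and the
// instantiated kernel uses
//   WARP_SIZE            = 32   (min(512, 32))
//   WARP_ITERATIONS      = 16   (512 / 32)
//   WARP_BATCH           = 1    (512 > 128)
//   ELEMENTS_PER_LDG_STG = 4    (WARP_ITERATIONS >= 4)
// and each 128-thread block covers batches_per_block = 4 * 1 = 4 rows.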
@@ -0,0 +1,75 @@
/*This code from NVIDIA Megatron:
 *     with minor changes. */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
                                                          float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
          seq_len, attn_batches););
  return softmax_results;
}

torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
                       torch::Tensor const& softmax_results_,
                       float scale_factor) {
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 3d tensor with dimensions [attn_batches, seq_len,
  // seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
                                                           float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor, seq_len, seq_len, attn_batches););

  // backward pass is completely in-place
  return output_grads;
}
}  // namespace scaled_upper_triang_masked_softmax
}  // namespace fused_softmax
}  // namespace multihead_attn
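A small hypothetical backward call (not part of the diff). Note that bwd_cuda writes the result into the contiguous copy of the incoming gradient buffer and returns it, so a caller should treat its output_grads argument as consumed.

// Sketch only: assumes the bwd_cuda declaration above is visible to the caller.
torch::Tensor example_upper_triang_bwd(torch::Tensor grads,
                                       torch::Tensor softmax_results) {
  return multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::
      bwd_cuda(grads, softmax_results, /*scale_factor=*/0.125f);
}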
279
extensions/csrc/cuda/type_shim.h
Normal file
@@ -0,0 +1,279 @@
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
/* Copyright 2020 The Microsoft DeepSpeed Team
   Copyright NVIDIA/apex
   This file is adapted from fused adam in NVIDIA/apex, commit a109f85
   Licensed under the MIT License.
*/
#include <ATen/ATen.h>

#include "compat.h"

#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                       \
  switch (TYPE) {                                                       \
    case at::ScalarType::Half: {                                        \
      using scalar_t = at::Half;                                        \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::BFloat16: {                                    \
      using scalar_t = at::BFloat16;                                    \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
  }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch (TYPEIN) {                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_in = float;                                               \
      switch (TYPEOUT) {                                                       \
        case at::ScalarType::Float: {                                          \
          using scalar_t_out = float;                                          \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t_out = at::Half;                                       \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t_out = at::BFloat16;                                   \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        default:                                                               \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");   \
      }                                                                        \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_in = at::Half;                                            \
      using scalar_t_out = at::Half;                                           \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_in = at::BFloat16;                                        \
      using scalar_t_out = at::BFloat16;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");        \
  }

// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
//   const at::Type& payload;
//   TypeShim(const at::Type& type) : payload(type) {}
//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
//   operator const at::Type&(){ return payload; };
//   // Enable dispatch switch statements to take *this directly for post-3aeb78
//   //operator at::ScalarType(){ return payload.; };
// };

#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                 \
  switch (TYPE) {                                                       \
    case at::ScalarType::Float: {                                       \
      using scalar_t_##LEVEL = float;                                   \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Half: {                                        \
      using scalar_t_##LEVEL = at::Half;                                \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
  }

#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)            \
  switch (TYPE) {                                                       \
    case at::ScalarType::Float: {                                       \
      using scalar_t_##LEVEL = float;                                   \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Half: {                                        \
      using scalar_t_##LEVEL = at::Half;                                \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Byte: {                                        \
      using scalar_t_##LEVEL = uint8_t;                                 \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
  }

#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)          \
  switch (TYPE) {                                                       \
    case at::ScalarType::Double: {                                      \
      using scalar_t_##LEVEL = double;                                  \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Float: {                                       \
      using scalar_t_##LEVEL = float;                                   \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Half: {                                        \
      using scalar_t_##LEVEL = at::Half;                                \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
  }

#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)               \
  switch (TYPE) {                                                       \
    case at::ScalarType::Double: {                                      \
      using scalar_t_##LEVEL = double;                                  \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Float: {                                       \
      using scalar_t_##LEVEL = float;                                   \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
  }

#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)        \
  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) {      \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Float &&                                 \
             PTYPE == at::ScalarType::Half) {                                  \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = at::Half;                                       \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Half &&                                  \
             PTYPE == at::ScalarType::Float) {                                 \
    using g_scalar_t_##LEVEL = at::Half;                                       \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
    using g_scalar_t_##LEVEL = at::Half;                                       \
    using p_scalar_t_##LEVEL = at::Half;                                       \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Float &&                                 \
             PTYPE == at::ScalarType::BFloat16) {                              \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = at::BFloat16;                                   \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::BFloat16 &&                              \
             PTYPE == at::ScalarType::Float) {                                 \
    using g_scalar_t_##LEVEL = at::BFloat16;                                   \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::BFloat16 &&                              \
             PTYPE == at::ScalarType::BFloat16) {                              \
    using g_scalar_t_##LEVEL = at::BFloat16;                                   \
    using p_scalar_t_##LEVEL = at::BFloat16;                                   \
    __VA_ARGS__;                                                               \
  } else {                                                                     \
    AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
             "'");                                                             \
  }

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize =
      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
      final = val;
    // __SYNCWARP();

#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize =
      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
      final = val;
    // __SYNCWARP();

#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final =
          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
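A hypothetical usage of DISPATCH_HALF_AND_BFLOAT (not part of the diff): the macro opens a scope in which scalar_t is bound to the tensor's runtime dtype, so a templated implementation can be instantiated for fp16 and bf16 only, with any other dtype raising AT_ERROR.

// Sketch only: scale_inplace_impl is an assumed helper defined elsewhere.
template <typename scalar_t>
void scale_inplace_impl(at::Tensor& t, float s);

void scale_inplace(at::Tensor& t, float s) {
  DISPATCH_HALF_AND_BFLOAT(t.scalar_type(), "scale_inplace",
                           scale_inplace_impl<scalar_t>(t, s););
}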