[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -0,0 +1,49 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale);
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int bias_correction,
const float weight_decay, const int grad_averaging,
const int mode, at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
"Computes and apply update for LAMB optimizer");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
}
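
All five bindings share the same fused-kernel calling convention: a chunk size, a one-element noop_flag tensor that lets the kernels short-circuit when an overflow has been detected, and a list of per-role tensor lists of equal length. Below is a minimal host-side sketch for the Adam entry point; the {grads, params, exp_avg, exp_avg_sq} list order and the chunk-size value are assumptions borrowed from the apex code this file is adapted from and should be verified against the corresponding .cu implementation.

// Sketch only: drives multi_tensor_adam_cuda for a group of fp32 CUDA tensors.
// The {g, p, m, v} list order and chunk_size are assumed apex conventions;
// mode/bias_correction/weight_decay use plain-Adam defaults.
void fused_adam_sketch(std::vector<at::Tensor> params,
                       std::vector<at::Tensor> grads,
                       std::vector<at::Tensor> exp_avg,
                       std::vector<at::Tensor> exp_avg_sq,
                       at::Tensor noop_flag, float lr, int step) {
  std::vector<std::vector<at::Tensor>> tensor_lists = {grads, params, exp_avg,
                                                       exp_avg_sq};
  multi_tensor_adam_cuda(/*chunk_size=*/2048 * 32, noop_flag, tensor_lists, lr,
                         /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f,
                         step, /*mode=*/0, /*bias_correction=*/1,
                         /*weight_decay=*/0.0f, /*div_scale=*/-1.0f);
}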

View File

@@ -0,0 +1,10 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
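
These shims keep the sources buildable across PyTorch versions: TORCH_CHECK falls back to the pre-1.0 AT_CHECK macro, and DATA_PTR resolves to data_ptr (1.3 and later) or the deprecated data accessor. Call sites later in this commit use it as in the comment sketch below.

// Example (sketch): DATA_PTR is a drop-in for Tensor::data_ptr, e.g.
//   float *mean_ptr = mean->DATA_PTR<float>();  // data_ptr on PyTorch >= 1.3,
//                                               // data on older releases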

View File

@@ -0,0 +1,446 @@
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
#include "cpu_adam.h"
#include <math.h>
#include <omp.h>
#include <string.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
// C++ interface
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
}
AVX_Data momentum_4;
this->simd_load(momentum_half_precision, _exp_avg + i,
momentum_cast_h + i, momentum_4);
AVX_Data variance_4;
this->simd_load(variance_half_precision, _exp_avg_sq + i,
variance_cast_h + i, variance_4);
AVX_Data param_4;
this->simd_load(param_half_precision, _params + i, params_cast_h + i,
param_4);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data =
SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data =
SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
if (_weight_decay > 0 && _adamw_mode) {
param_4.data =
SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
this->simd_store(param_half_precision, _params + i, params_cast_h + i,
param_4);
this->simd_store(momentum_half_precision, _exp_avg + i,
momentum_cast_h + i, momentum_4);
this->simd_store(variance_half_precision, _exp_avg_sq + i,
variance_cast_h + i, variance_4);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
if (loss_scale > 0) {
grad /= loss_scale;
}
float param =
param_half_precision ? (float)params_cast_h[k] : _params[k];
float momentum =
momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
float variance = variance_half_precision ? (float)variance_cast_h[k]
: _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
if (param_half_precision)
params_cast_h[k] = (__half)param;
else
_params[k] = param;
if (momentum_half_precision)
momentum_cast_h[k] = (__half)(momentum);
else
_exp_avg[k] = momentum;
if (variance_half_precision)
variance_cast_h[k] = (__half)(variance);
else
_exp_avg_sq[k] = variance;
}
}
}
}
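// In scalar form, the loops above compute the standard Adam/AdamW update,
// with step_size = -alpha / (1 - beta1^t) and
// bias_correction2 = 1 / sqrt(1 - beta2^t) supplied by update_state():
//   g  = g / loss_scale                  (only if loss_scale > 0)
//   g += weight_decay * p                (L2 mode, i.e. !_adamw_mode)
//   m  = beta1 * m + (1 - beta1) * g
//   v  = beta2 * v + (1 - beta2) * g^2
//   u  = m / (sqrt(v) * bias_correction2 + eps)
//   p  = p * (1 - alpha * weight_decay)  (AdamW mode only)
//   p  = p + step_size * u
// The SIMD path and the scalar tail both follow exactly this sequence.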
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
AVX_Data grad_4[4];
AVX_Data momentum_4[4];
AVX_Data variance_4[4];
AVX_Data param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_load(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_store(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
}
}
}
#endif
if (_param_size > rounded_size)
Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
float betta2_minus1 = 1 - _betta2;
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay_4;
if (_weight_decay > 0)
weight_decay_4.data =
(_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
AVX_Data grad_4[8];
AVX_Data momentum_4[8];
AVX_Data variance_4[8];
AVX_Data param_4[8];
#pragma unroll 8
for (int j = 0; j < 8; j++) {
this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
if (loss_scale > 0) {
AVX_Data loss_scale_vec;
loss_scale_vec.data = SIMD_SET(loss_scale);
grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
}
this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_load(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
}
momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
momentum_4[j].data =
SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
variance_4[j].data =
SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
grad_4[j].data = SIMD_SQRT(variance_4[j].data);
grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j].data =
SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
}
param_4[j].data =
SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
this->simd_store(variance_half_precision,
_exp_avg_sq + i + SIMD_WIDTH * j,
variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
}
}
}
#endif
if (_param_size > rounded_size)
Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf),
(exp_avg.options().dtype() == at::kHalf),
(exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<Adam_Optimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &Adam_Optimizer::step);
}
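
A minimal sketch of exercising the optimizer from C++ with fp32 CPU tensors; the Python binding above exposes the same constructor and step call as CPUAdamOptimizer. Tensor sizes and hyperparameters here are illustrative only.

// Sketch only: not part of this extension, just a smoke test of the API.
void cpu_adam_smoke_test() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
  torch::Tensor params = torch::randn({1024}, opts);
  torch::Tensor grads = torch::randn({1024}, opts);
  torch::Tensor exp_avg = torch::zeros({1024}, opts);
  torch::Tensor exp_avg_sq = torch::zeros({1024}, opts);
  Adam_Optimizer opt(/*alpha=*/1e-3, /*betta1=*/0.9, /*betta2=*/0.999,
                     /*eps=*/1e-8, /*weight_decay=*/0, /*adamw_mode=*/true);
  // loss_scale = -1 skips gradient unscaling (see the `loss_scale > 0`
  // checks in the Step_* implementations above).
  opt.step(/*step=*/1, /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999,
           /*epsilon=*/1e-8, /*weight_decay=*/0, /*bias_correction=*/true,
           params, grads, exp_avg, exp_avg_sq, /*loss_scale=*/-1.f);
}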

View File

@@ -0,0 +1,185 @@
/*
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
#pragma once
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <torch/extension.h>
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
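// Note: ROUND_DOWN assumes `step` is a power of two, e.g.
// ROUND_DOWN(1003, 8) == 1000 and ROUND_DOWN(1003, 32) == 992; the leftover
// elements are handled by the scalar tail loop in Step_1.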
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#endif
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
__m256 data;
#endif
// float data_f[16];
};
#endif
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);
class Adam_Optimizer {
public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~Adam_Optimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
data.data = SIMD_LOAD_HALF(h_ptr);
} else {
data.data = SIMD_LOAD(ptr);
}
}
inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
SIMD_STORE_HALF(h_ptr, data.data);
} else {
SIMD_STORE(ptr, data.data);
}
}
#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
};
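
For reference, update_state derives the two bias-correction factors from the running beta powers maintained by IncrementStep; with the default beta1 = 0.9 and beta2 = 0.999 the values evolve as in the worked example below.

// step = 1:  _betta1_t = 0.9,     _betta2_t = 0.999
//            _bias_correction1 = 1 - 0.9 = 0.1
//            _bias_correction2 = 1 / sqrt(1 - 0.999) ≈ 31.62
// step = 10: _betta1_t ≈ 0.3487,  _betta2_t ≈ 0.9900
//            _bias_correction1 ≈ 0.6513
//            _bias_correction2 = 1 / sqrt(1 - 0.9900) ≈ 10.02
// Both factors tend to 1 as training proceeds, so the effective step size
// -_alpha / _bias_correction1 used by Step_* approaches -_alpha.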

View File

@@ -0,0 +1,312 @@
/* Copyright 2021 The LightSeq Team
Copyright Tencent/TurboTransformers
This block_reduce_n is adapted from Tencent/TurboTransformers
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
enum class ReduceType { kMax = 0, kSum };
const unsigned int WARP_REDUCE_MASK = 0xffffffff;
const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32;
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0) shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
val = warpReduceSum<T>(val);
return val;
}
template <ReduceType Rtype, int Num>
__inline__ __device__ void blockReduce(float *pval);
// use template to make code more concise
template <ReduceType Rtype, int Num>
__inline__ __device__ void warpReduce(float *pval);
// static
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32));
}
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
float val0_tmp, val1_tmp;
#define WarpReduceMaxOneStep(a, b) \
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
*(pval) = max(val0_tmp, *(pval)); \
*(pval + 1) = max(val1_tmp, *(pval + 1));
WarpReduceMaxOneStep(16, 32);
WarpReduceMaxOneStep(8, 32);
WarpReduceMaxOneStep(4, 32);
WarpReduceMaxOneStep(2, 32);
WarpReduceMaxOneStep(1, 32);
#undef WarpReduceMaxOneStep
}
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32);
}
/*
 * Unroll the for loop in warpReduce to
 * improve instruction issue efficiency.
 * ElemX means there are X numbers to be summed.
 */
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b) \
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
*(pval + 0) += val0_tmp; \
*(pval + 1) += val1_tmp
WarpReduceSumOneStep(16, 32);
WarpReduceSumOneStep(8, 32);
WarpReduceSumOneStep(4, 32);
WarpReduceSumOneStep(2, 32);
WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceSumOneStep(a, b) \
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
*(pval + 0) += val0_tmp; \
*(pval + 1) += val1_tmp; \
*(pval + 2) += val2_tmp; \
*(pval + 3) += val3_tmp
WarpReduceSumOneStep(16, 32);
WarpReduceSumOneStep(8, 32);
WarpReduceSumOneStep(4, 32);
WarpReduceSumOneStep(2, 32);
WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 1>(float *pval) {
const int num = 1;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kSum, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = 0.f;
}
}
warpReduce<ReduceType::kSum, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 2>(float *pval) {
const int num = 2;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kSum, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = 0.f;
}
}
warpReduce<ReduceType::kSum, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 4>(float *pval) {
const int num = 4;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kSum, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = 0.f;
}
}
warpReduce<ReduceType::kSum, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 1>(float *pval) {
const int num = 1;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kMax, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = REDUCE_FLOAT_INF_NEG;
}
}
warpReduce<ReduceType::kMax, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 2>(float *pval) {
const int num = 2;  // match the Num = 2 specialization so both values are reduced
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kMax, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = REDUCE_FLOAT_INF_NEG;
}
}
warpReduce<ReduceType::kMax, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
const int num = 4;  // match the Num = 4 specialization so all four values are reduced
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kMax, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = REDUCE_FLOAT_INF_NEG;
}
}
warpReduce<ReduceType::kMax, num>(pval);
}
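
A minimal sketch of how these helpers are typically consumed (the MoE kernels later in this commit follow the same pattern): each thread accumulates a private partial value over a strided slice of its row, one blockReduce call combines the partials, and a single thread writes the result. The kernel name and launch shape are illustrative, and blockDim.x is assumed to be a multiple of the warp size.

// Sketch only: one block reduces one row of `inp` into `out[blockIdx.x]`.
__global__ void row_sum_sketch(const float *inp, float *out, const int cols) {
  float val = 0.f;
  for (int c = threadIdx.x; c < cols; c += blockDim.x) {
    val += inp[blockIdx.x * cols + c];
  }
  blockReduce<ReduceType::kSum, 1>(&val);  // requires blockDim.x % 32 == 0
  if (threadIdx.x == 0) out[blockIdx.x] = val;
}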

View File

@@ -0,0 +1,141 @@
/* This code is from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "compat.h"
namespace {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
assert(input.sizes()[i + idiff] == normalized_shape[i]);
n2 *= normalized_shape[i];
}
n1 = 1;
for (int i = 0; i < idiff; ++i) {
n1 *= input.sizes()[i];
}
}
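// Example: an input of shape [batch, seq_len, hidden] normalized over
// normalized_shape = [hidden] yields n2 = hidden (the reduced extent) and
// n1 = batch * seq_len (the number of independent rows to normalize).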
void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
at::Tensor beta) {
TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
std::stringstream ss;
ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
<< "containing at least one element, but got normalized_shape="
<< normalized_shape;
throw std::runtime_error(ss.str());
}
auto input_shape = input.sizes();
auto input_ndim = input.dim();
if (input_ndim < normalized_ndim ||
!input_shape.slice(input_ndim - normalized_ndim)
.equals(normalized_shape)) {
std::stringstream ss;
ss << "Given normalized_shape=" << normalized_shape
<< ", expected input with shape [*";
for (auto size : normalized_shape) {
ss << ", " << size;
}
ss << "], but got input of size" << input_shape;
throw std::runtime_error(ss.str());
}
compute_n1_n2(input, normalized_shape, n1, n2);
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
} // namespace
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
at::Tensor *input, int n1, int n2,
at::IntArrayRef normalized_shape, at::Tensor *gamma,
at::Tensor *beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor output =
at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean =
at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape,
&gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
at::Tensor *invvar, at::Tensor *input, int n1,
int n2, at::IntArrayRef normalized_shape,
at::Tensor *gamma, at::Tensor *beta,
double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon,
&grad_input, &grad_gamma, &grad_beta);
return {grad_input, grad_gamma, grad_beta};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}
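
For reference, the shapes implied by check_args and the allocations above, assuming an input of shape [B, S, H] normalized over its last dimension:

// forward_affine(input[B, S, H], normalized_shape = [H], gamma[H], beta[H], eps)
//   -> output[B, S, H]  (dtype of gamma)
//      mean[B * S]      (fp32, one value per normalized row, n1 = B * S)
//      invvar[B * S]    (fp32, 1 / sqrt(var + eps) per row)
// backward_affine(dout, mean, invvar, input, normalized_shape, gamma, beta, eps)
//   -> (grad_input, grad_gamma, grad_beta) with shapes matching
//      (input, gamma, beta).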

View File

@@ -0,0 +1,683 @@
/* This code is from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <cuda.h>
#include <cuda_runtime.h>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include "type_shim.h"
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
U& mu, U& sigma2, U& count) {
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
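// cuChanOnlineSum merges two Welford partials (Chan et al.'s parallel update):
// with delta = muB - mu and nX = nA + nB,
//   mu'     = (nA * mu + nB * muB) / nX
//   sigma2' = sigma2 + sigma2B + delta^2 * nA * nB / nX
// where sigma2 accumulates the sum of squared deviations; it is divided by
// n2 only once, at the end of cuWelfordMuSigma2.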
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T* __restrict__ vals, const int n1,
const int n2, const int i1, U& mu, U& sigma2,
U* buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu = U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1 * n2;
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l + k]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2 * threadIdx.y];
U sigma2B = ubuf[2 * threadIdx.y + 1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
}
}
}
template <>
__device__ void cuWelfordMuSigma2(const at::Half* __restrict__ vals,
const int n1, const int n2, const int i1,
float& mu, float& sigma2, float* buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu = float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1 * n2;
int l = 8 * thrx;
if ((((size_t)lvals) & 3) != 0) {
// lvals may be only 2-byte (half) aligned here; let the first thread
// consume one element so the remaining accesses are 4-byte aligned
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2*)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2 * threadIdx.y];
float sigma2B = ubuf[2 * threadIdx.y + 1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
}
}
}
template <typename U>
U rsqrt(U v) {
return U(1) / sqrt(v);
}
template <>
float rsqrt(float v) {
return rsqrtf(v);
}
template <>
double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory<float> {
__device__ float* getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
} // namespace
template <typename T, typename U, typename V>
__global__ void cuApplyLayerNorm(V* __restrict__ output_vals,
U* __restrict__ mean, U* __restrict__ invvar,
const T* __restrict__ vals, const int n1,
const int n2, const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu, sigma2;
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
const T* lvals = vals + i1 * n2;
V* ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end, const int n2,
const U* __restrict__ mean, const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] =
curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U* warp_buf1, U* warp_buf2,
const T* input, const V* dout, const int i1_end, const int n2,
const U* __restrict__ mean, const U* __restrict__ invvar) {
int i1 = i1_block + thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1 * n2 + i2;
int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(
const V* __restrict__ dout, const T* __restrict__ input, const int n1,
const int n2, const U* __restrict__ mean, const U* __restrict__ invvar,
U epsilon, U* part_grad_gamma, U* part_grad_beta) {
const int numsegs_n1 =
(n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
const int i1_beg_plus_one =
(blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x + 1;
const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
const int thr_load_row_off =
(threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y *
// blockDim.y + (blockDim.y -
// 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
i1_block += blockDim.y * blockDim.y) {
cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off,
row_stride, warp_buf1, warp_buf2, input, dout,
i1_end, n2, mean, invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k * blockDim.y;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size, const int n1,
const int n2, V* grad_gamma,
V* grad_beta) {
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, typename V>
__global__ void cuComputeGradInput(const V* __restrict__ dout,
const T* __restrict__ input, const int n1,
const int n2, const U* __restrict__ mean,
const U* __restrict__ invvar, U epsilon,
const V* gamma, T* grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1 * n2;
const V* k_dout = dout + i1 * n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
}
template <typename T, typename U, typename V>
void HostApplyLayerNorm(V* output, U* mean, U* invvar, const T* input, int n1,
int n2, double epsilon, const V* gamma, const V* beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32, 4, 1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma, at::Tensor* beta, double epsilon) {
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(), n1, n2, epsilon,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);)
}
template <typename T, typename U, typename V>
void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar,
at::Tensor* input, int n1, int n2, const V* gamma,
const V* beta, double epsilon, T* grad_input,
V* grad_gamma, V* grad_beta) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32, 4, 1);
const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
const int nshared2_a =
2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size, n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32, 8, 1);
const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size,
n1, n2, grad_gamma, grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32, 4, 1);
int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
grad_input);
}
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean,
at::Tensor* invvar, at::Tensor* input, int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma, at::Tensor* beta,
double epsilon, at::Tensor* grad_input,
at::Tensor* grad_gamma, at::Tensor* grad_beta) {
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(), input, n1, n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
}

View File

@@ -0,0 +1,97 @@
#include <torch/extension.h>
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask, torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
}
torch::Tensor moe_dispatch_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
}
torch::Tensor moe_combine_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
dest_idx);
}
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
logits, mask, dest_idx);
}
torch::Tensor moe_cumsum(torch::Tensor mask) {
CHECK_INPUT(mask);
return cumsum_sub_one_in_dim0(mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");
m.def("dispatch_forward", &moe_dispatch_forward,
"Forward operation in MoE dispatch function");
m.def("dispatch_backward", &moe_dispatch_backward,
"Backward operation in MoE dispatch function");
m.def("combine_forward", &moe_combine_forward,
"Combine operation in MoE combine function");
m.def("combine_backward", &moe_combine_backward,
"Combine operation in MoE combine function");
}

View File

@@ -0,0 +1,659 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, pack);
BlockStore(ts_store).Store(src_row + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row1 + idx, pack);
BlockStore(ts_store).Store(dst_row2 + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] += pack2[i];
}
BlockStore(ts_store).Store(src_row + idx, pack1);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] *= weight;
}
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
T *weight_grad, const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens[pack_size];
float thread_sum = 0;
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row + idx, tokens);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum += grad[i] * tokens[i];
grad[i] *= weight;
}
BlockStore(ts_store).Store(src_row + idx, grad);
}
blockReduce<ReduceType::kSum, 1>(&thread_sum);
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
}
BlockStore(ts_store).Store(dst_row + idx, pack1);
}
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1,
T *weight_grad2, const T weight1,
const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
sgrad2[pack_size];
float thread_sum[2] = {0, 0};
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum[0] += grad[i] * tokens1[i];
thread_sum[1] += grad[i] * tokens2[i];
sgrad1[i] = weight1 * grad[i];
sgrad2[i] = weight2 * grad[i];
}
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
}
blockReduce<ReduceType::kSum, 2>(thread_sum);
if (threadIdx.x == 0)
*weight_grad1 = static_cast<T>(thread_sum[0]);
else if (threadIdx.x == 1)
*weight_grad2 = static_cast<T>(thread_sum[1]);
}
// DISPATCH KERNELS --------------------------------
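// Each token row can be routed to at most two experts (top-2 gating): mask1/mask2
// mark whether a row is selected for each slot, and dest1/dest2 give the destination
// row inside the expert buffer. The selectors below choose the one- or two-destination
// copy accordingly; mask2 == nullptr degenerates to the top-1 case.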
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
else if (indicator1 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
else if (indicator1 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
int *mask1, int *mask2, int *dest1,
int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>(
batch_tokens + (row * h), expert_input + (dest1[row] * h),
expert_input + (dest2[row] * h), h, mask1[row], indicator2);
}
template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>(
tokens_grad + (row * h), expert_grad + (dest1[row] * h),
expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
}
// COMBINE KERNELS --------------------------------
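// Combine weights each selected expert output row by its gating logit and sums the
// (up to two) contributions back into the token row; the backward pass additionally
// accumulates the gradient w.r.t. each logit through a block-wide reduction.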
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, T *tks_row1, T *tks_row2,
T *wt_grad1, T *wt_grad2, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1,
wt_grad2, weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
wt_grad1, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
wt_grad2, weight2, cols);
else
return;
}
template <typename T, int block_size, int pack_size>
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
T *logits, int *mask1, int *mask2, int *dest1,
int *dest2, const int e, const int c,
const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e);
moe_cb_fwd_selector<T, block_size, pack_size>(
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
indicator2);
}
template <typename T, int block_size, int pack_size>
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
moe_cb_bwd_selector<T, block_size, pack_size>(
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
tokens_grad + (row * h), h, tks + (dest1[row] * h),
tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
row_log[eid2], mask1[row], indicator2);
}
// CUMSUM KERNEL --------------------------------
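// Work-efficient (Blelloch-style) scan over dim 0: one block per expert column
// (blockIdx.x), processing the s rows in tiles of block_size * pack_size elements.
// `last_sum` starts at -1, so the output is the inclusive cumulative sum minus one,
// which lets cumulative expert counts be used directly as 0-based row offsets.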
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
const int e) {
assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
__shared__ int temp[block_size + 1];
int pack[pack_size];
for (int idx = 0; idx < s; idx += bpack_size) {
int offset = 1;
if (idx + tps < s) {
temp[tid] = inputs[tps * e + bid];
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
pack[i] = inputs[(tps + i) * e + bid];
}
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
temp[tid] += pack[i];
}
}
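// Up-sweep (reduce) phase: build partial sums in shared memory.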
for (int i = block_size >> 1; i > 0; i >>= 1) {
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1;
temp[j + offset] += temp[j];
}
offset <<= 1;
}
if (tid == 0) {
temp[block_size] = temp[block_size - 1];
temp[block_size - 1] = 0;
}
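// The block total was saved to temp[block_size] and the last element zeroed above;
// the down-sweep phase now turns the partial sums into an exclusive scan.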
for (int i = 1; i < block_size; i <<= 1) {
offset >>= 1;
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
temp[j] = temp[k];
temp[k] += ts;
}
}
__syncthreads();
if (tid == 0) temp[0] = temp[block_size];
__syncthreads();
if (idx + tps < s) {
temp[tid + 1] += last_sum;
#pragma unroll
for (int i = pack_size - 1; i > 0; --i) {
outputs[(tps + i) * e + bid] = temp[tid + 1];
temp[tid + 1] -= pack[i];
}
outputs[tps * e + bid] = temp[tid + 1];
}
__syncthreads();
last_sum += temp[0];
inputs += bpack_size * e;
outputs += bpack_size * e;
}
}
// LAUNCH FUNCTIONS --------------------------------
template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
int *mask2, int *dest1, int *dest2, const int s,
const int h) {
if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_fwd_kernel<T, 32, 8>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_fwd_kernel<T, 32, 16>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_fwd_kernel<T, 64, 16>
<<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else
moe_dpch_fwd_kernel<T, 128, 16>
<<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
}
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
int *dest1, int *dest2, const int s, const int h) {
if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_bwd_kernel<T, 32, 8>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_bwd_kernel<T, 32, 16>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_bwd_kernel<T, 64, 16>
<<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else
moe_dpch_bwd_kernel<T, 128, 16>
<<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
}
template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2, int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else if (h < 512)
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else if (h < 1024)
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else if (h < 2048)
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
e, c, h);
else
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1,
dest2, e, c, h);
}
template <typename T>
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
T *logits_grad, int *mask1, int *mask2, int *dest1,
int *dest2, const int s, const int e, const int c,
const int h) {
if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
dest1, dest2, e, c, h);
else // if (h < 512)
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
dest1, dest2, e, c, h);
// else if (h < 1024)
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
// dest1, dest2, e, c, h);
// else
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
// dest1, dest2, e, c, h);
}
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512)
cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
else if (s <= 1024)
cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
else if (s <= 2048)
cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
else
cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
}
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
}
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros(
{ec, h},
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
batch_tokens.scalar_type(), "moe dispatch forward",
moe_dpch_fwd_launch<scalar_t>(
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(),
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
return res;
}
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros(
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_grad.scalar_type(), "moe dispatch backward",
moe_dpch_bwd_launch<scalar_t>(
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(),
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
return res;
}
torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype());
auto res = torch::zeros(
{s, h},
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_tokens.scalar_type(), "moe combine forward",
moe_cb_fwd_launch<scalar_t>(
expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
logits.data<scalar_t>(), mask[0].data<int>(),
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
h));
return res;
}
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype());
auto egrad = torch::zeros(
{e * c, h},
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
wgrad = torch::zeros(
{s, e}, torch::dtype(logits.dtype()).device(logits.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
tokens_grad.scalar_type(), "moe combine backward",
moe_cb_bwd_launch<scalar_t>(
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
wgrad.data<scalar_t>(), mask[0].data<int>(),
k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
h));
return {egrad, wgrad};
}
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32);
const int s = mask.size(0), e = mask.size(1);
auto res =
torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
return res;
}

View File

@@ -0,0 +1,146 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1  // Decoupled weight decay mode (AdamW)
} adamMode_t;
using MATH_T = float;
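// Per-chunk fused Adam update; all math is done in fp32 (MATH_T) regardless of the
// storage types T_g / T_p:
//   m <- beta1 * m + (1 - beta1) * g
//   v <- beta2 * v + (1 - beta2) * g^2
//   p <- p - lr * (m / beta1_correction) / (sqrt(v / beta2_correction) + epsilon)
// ADAM_MODE_0 folds weight decay into the gradient (classic L2), while ADAM_MODE_1
// adds decoupled weight decay to the update (AdamW).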
template <typename T_g, typename T_p>
struct AdamFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta1_correction,
const float beta2_correction, const float epsilon, const float lr,
adamMode_t mode, const float decay, const float div_scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// could potentially be used to pass in a list of scalars
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_g *g = (T_g *)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T_p *p = (T_p *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T_p *m = (T_p *)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T_p *v = (T_p *)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (div_scale > 0) r_g[ii] /= div_scale;
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction, const float weight_decay,
const float div_scale) {
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
DISPATCH_FLOAT_AND_HALF_FOR_G_P(
tensor_lists[0][0].scalar_type(), tensor_lists[1][0].scalar_type(), 0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<g_scalar_t_0, p_scalar_t_0>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay, div_scale);)
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@@ -0,0 +1,130 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void *addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
// full int.
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
volatile int *noop_flag, T tl,
U callable, ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(chunk_size, noop_flag, tl, args...);
}
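// Host-side chunking harness: every tensor is split into fixed-size chunks, and the
// (address, size, chunk index) metadata is packed into a TensorListMetadata struct
// that is passed by value as a kernel argument. A kernel launch is issued whenever
// the tensor slots or block slots fill up, or when the last chunk has been queued.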
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
int block_size, int chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size();
l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0,
"Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory =
(contiguous_memory ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
"Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
// << std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
// << std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}

View File

@@ -0,0 +1,382 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_l2norm_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
int src_offset) {
typedef
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename x_t>
struct L2NormFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be
// sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]);
vals[ii] += next * next;
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]);
vals[ii] += next * next;
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] += final;
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
// Probably better to template, but since we are not likely to support other
// norm types, keep this as a separate max-norm functor.
template <typename x_t>
struct MaxNormFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
x_t *x = (x_t *)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be
// sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
}
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_x, x, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
float next = static_cast<float>(r_x[ii]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
float next = static_cast<float>(x[i]);
vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
}
}
}
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
float final = reduce_block_into_lanes_max_op(s_vals, val);
if (threadIdx.x == 0) {
if (!isfinite(final))
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
if (per_tensor)
output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
max_chunks_per_tensor +
chunk_idx] = final;
}
}
};
__global__ void cleanup(float *output, float *output_per_tensor, float *ret,
float *ret_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
__shared__ float vals[512];
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(final);
}
if (per_tensor) {
float *output_this_tensor =
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
}
}
__global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
float *ret_per_tensor, bool per_tensor,
int max_chunks_per_tensor, int norm_type,
float alpha, float beta) {
__shared__ float vals[512];
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
if (norm_type == 0) {
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
} else {
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
}
}
if (per_tensor) {
float *output_this_tensor =
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
if (norm_type == 0) {
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val = fmaxf(fabsf(val), fabsf(output_this_tensor[i]));
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] =
alpha * ret_per_tensor[blockIdx.x] + beta * final;
} else {
float val = 0;
for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
val += output_this_tensor[i];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] *
ret_per_tensor[blockIdx.x] +
beta * final);
}
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor =
per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor =
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
per_tensor, max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to
// end. I could get rid of it by hacking the functor + multi tensor harness
// with persistence logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.DATA_PTR<float>(),
per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
ret.DATA_PTR<float>(),
per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr, per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
// Compute and update grad norm
// Here we use a per-tensor norm and blend the new norm (n) with the old norm (gn):
//   L-2:   gn = sqrt(a * gn^2 + b * n^2)
//   L-inf: gn = a * gn + b * n
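// norm_type == 0 selects the L-inf (max) blend above; any other value selects the
// L-2 blend. The per-tensor blended norms are written into `out` by cleanup_v2.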
void multi_tensor_norm_out_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, at::Tensor out,
const float alpha, const float beta, const int norm_type) {
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(),
"noop flag should be on the same device as tensors");
// the blended global norm is not consumed here, so an uninitialized (empty) buffer suffices
auto output = at::empty({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor =
(tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
// Although it is a single write then read, the buffer still needs to be zeroed,
// since trailing elements also participate in the cleanup kernel
output_per_tensor =
at::zeros({ntensors * max_chunks_per_tensor}, float_options);
if (norm_type == 0) {
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
MaxNormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
} else {
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
}
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to
// end. I could get rid of it by hacking the functor + multi tensor harness
// with persistence logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
// Adding the following device guard since it sometimes happens that the
// tensors are on one device while the cuda stream is on another device, which
// results in an ILLEGAL MEM ACCESS error.
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup_v2<<<ntensors, 512, 0, stream>>>(
output.DATA_PTR<float>(), output_per_tensor.DATA_PTR<float>(),
ret.DATA_PTR<float>(), out.DATA_PTR<float>(), true, max_chunks_per_tensor,
norm_type, alpha, beta);
return;
}

View File

@@ -0,0 +1,354 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_lamb.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
int src_offset) {
typedef
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
typedef enum {
MOMENT_MODE_0 = 0, // L2 regularization mode
MOMENT_MODE_1 = 1 // Decoupled weight decay mode
} adamMode_t;
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
using MATH_T = float;
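// LAMB stage 1: computes an Adam-style update per element (scaling gradients down
// when the global grad norm exceeds max_global_grad_norm) and writes it back into
// the gradient buffer; stage 2 then rescales that update by the per-tensor trust
// ratio ||param|| / ||update|| before applying it to the parameters.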
template <typename T>
struct LAMBStage1Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl,
const float beta1, const float beta2, const float beta3,
const float beta1_correction, const float beta2_correction,
const float epsilon, adamMode_t mode, const float decay,
const float *global_grad_norm, const float max_global_grad_norm) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float clipped_global_grad_norm =
(*global_grad_norm) > max_global_grad_norm
? (*global_grad_norm) / max_global_grad_norm
: 1.0f;
T *g = (T *)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T *p = (T *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T *m = (T *)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T *v = (T *)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(g) &&
is_aligned(p) && is_aligned(m) && is_aligned(v)) {
T l_g[ILP];
T l_p[ILP];
T l_m[ILP];
T l_v[ILP];
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(l_g, g, 0, i_start);
if (decay != 0) load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_g[ii] = l_g[ii];
if (decay == 0) {
r_p[ii] = MATH_T(0);
} else {
r_p[ii] = l_p[ii];
}
r_m[ii] = l_m[ii];
r_v[ii] = l_v[ii];
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay * r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
} else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
l_p[ii] = r_p[ii];
l_m[ii] = r_m[ii];
l_v[ii] = r_v[ii];
}
// store
load_store(g, l_p, i_start, 0);
load_store(m, l_m, i_start, 0);
load_store(v, l_v, i_start, 0);
}
} else {
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
// special ?optimization? for lamb stage 1
if (decay == 0) {
r_p[ii] = MATH_T(0);
} else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay * r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
} else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased / denom) + (decay * r_p[ii]);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
g[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
}
};
// Step 2 reads in the 'update' value and the per-tensor param_norm and update_norm.
// It computes the new parameter value.
template <typename T>
struct LAMBStage2Functor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<2> &tl,
const float *per_tensor_param_norm, const float *per_tensor_update_norm,
const float learning_rate, const float decay, bool use_nvlamb) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0)) {
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f)
? learning_rate * (param_norm / update_norm)
: learning_rate;
}
T *update = (T *)tl.addresses[0][tensor_loc];
update += chunk_idx * chunk_size;
T *p = (T *)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) &&
is_aligned(update)) {
T r_p[ILP];
T r_update[ILP];
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_p, p, 0, i_start);
load_store(r_update, update, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_p[ii] = static_cast<MATH_T>(r_p[ii]) -
(ratio * static_cast<MATH_T>(r_update[ii]));
}
load_store(p, r_p, i_start, 0);
}
} else {
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
}
}
}
}
}
};
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int bias_correction,
const float weight_decay, const int grad_averaging,
const int mode, at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python) {
using namespace at;
// Master weight and 32-bit momentum (potentially changing) are not handled by
// this, so we assume all tensors are of the same type
bool use_nvlamb =
use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Handle grad averaging mode
float beta3 = 1.0f;
if (grad_averaging == 1) beta3 = 1 - beta1;
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(),
tensor_lists.begin() + 1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin() + 1,
tensor_lists.begin() + 2);
// Compute per tensor param norm
auto param_norm_tuple =
multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
// We now modify grad in place to store the update before computing its norm.
// Generally this is not an issue since grads are modified in the step() method all
// the time. We could also grab a list of empty tensors to avoid this, but I'd like
// to save space/cpu code
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
LAMBStage1Functor<scalar_t_0>(), beta1, beta2,
beta3,  // 1-beta1 or 1 depending on the averaging mode
bias_correction1, bias_correction2, epsilon,
(adamMode_t)mode, weight_decay,
global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
// Compute update norms
auto update_norm_tuple =
multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);
std::vector<std::vector<at::Tensor>> grad_param_list(
tensor_lists.begin(), tensor_lists.begin() + 2);
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
lr, weight_decay, use_nvlamb);)
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@@ -0,0 +1,125 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
int src_offset) {
typedef
typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
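// Fused copy-and-scale with overflow check: each input element is multiplied by
// `scale` and written to the output list; if any value is non-finite, the shared
// noop flag is set so the step can be skipped by the caller.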
template <typename in_t, typename out_t>
struct ScaleFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int *noop_gmem,
TensorListMetadata<2> &tl,
float scale) {
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
in_t *in = (in_t *)tl.addresses[0][tensor_loc];
in += chunk_idx * chunk_size;
out_t *out = (out_t *)tl.addresses[1][tensor_loc];
out += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
bool finite = true;
in_t r_in[ILP];
out_t r_out[ILP];
// to make things simple, we put aligned case in a different code path
if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) &&
is_aligned(out)) {
for (int i_start = threadIdx.x;
i_start * ILP < n && i_start * ILP < chunk_size;
i_start += blockDim.x) {
// load
load_store(r_in, in, 0, i_start);
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
// store
load_store(out, r_out, i_start, 0);
}
} else {
// Non-divergent exit condition for __syncthreads, not necessary here
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point
// unrolling the write loop, since writes just fire off once their LDGs
// arrive. Put another way, the STGs are dependent on the LDGs, but not
// on each other. There is still compute ILP benefit from unrolling the
// loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) out[i] = r_out[ii];
}
}
}
if (!finite)
*noop_gmem =
1; // Blindly fire off a write. These will race but that's ok.
}
};
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale) {
using namespace at;
// The output (downscaled) type is always float.
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_AND_HALF(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_AND_HALF(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(),
scale);))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
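// Illustrative host-side usage sketch (an assumption, not part of the original
// file): tensor_lists[0] holds the inputs, tensor_lists[1] the pre-allocated
// outputs, and noop_flag is a one-element int32 CUDA tensor that the kernel
// sets to 1 if any non-finite value is encountered.
static void example_unscale(const std::vector<at::Tensor> &inputs,
                            const std::vector<at::Tensor> &outputs,
                            float inv_scale) {
  at::Tensor noop =
      at::zeros({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
  std::vector<std::vector<at::Tensor>> lists = {inputs, outputs};
  // chunk_size is a tuning knob; 2048 * 32 is a commonly used value.
  multi_tensor_scale_cuda(2048 * 32, noop, lists, inv_scale);
  // A nonzero noop flag signals overflow; callers typically skip the step.
}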

View File

@@ -0,0 +1,167 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>
#include "compat.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
/**
* Perform fused SGD on multiple buffers
* N: number of tensors
* tl[0] : gradients
* tl[1] : weights
* tl[2] : momentum buffers
* tl[3] : fp16 weights (if appropriate)
* wd : weight_decay (scalar)
* momentum : momentum (scalar)
* dampening : momentum dampening (scalar)
* lr : learning rate (scalar)
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template <typename T_grad, typename T_weight>
struct SGDFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
float wd, float momentum, float dampening, float lr, bool nesterov,
bool first_run, bool wd_after_momentum, float scale) {
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad *grad_in = (T_grad *)tl.addresses[0][tensor_loc];
grad_in += chunk_idx * chunk_size;
T_weight *weight_in = (T_weight *)tl.addresses[1][tensor_loc];
weight_in += chunk_idx * chunk_size;
T_weight *mom_in = (T_weight *)tl.addresses[2][tensor_loc];
mom_in += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for (int i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
incoming_grads[ii] = static_cast<float>(grad_in[i]) * scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
// apply weight decay before momentum if necessary
if (wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if (momentum != 0.f) {
if (!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum +
(1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if (nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if (wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
}
}
}
}
};
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type
// 1. fp16, fp16, fp16
// 2. fp32, fp32, fp32
// 3. fp16, fp32, fp32
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
  // Case 1. fp16, fp16, fp16 (no separate fp16 parameter copy)
if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<at::Half, at::Half>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
}
// Case 2. fp32, fp32, fp32
else if (grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<float, float>(), wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32
else if (grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float && num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<at::Half, float>(), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum,
scale);
} else {
AT_ERROR(
"multi_tensor_sgd only supports some combinations of gradient & weight "
"types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type,
", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
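// Illustrative host-side usage sketch (an assumption, not part of the original
// file), following the list layout documented above: tl[0] = gradients,
// tl[1] = weights, tl[2] = momentum buffers (the optional fp16 copy list is
// not used by this 3-list build).
static void example_sgd_step(const std::vector<at::Tensor> &grads,
                             const std::vector<at::Tensor> &weights,
                             const std::vector<at::Tensor> &momenta,
                             float lr) {
  at::Tensor noop =
      at::zeros({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
  std::vector<std::vector<at::Tensor>> lists = {grads, weights, momenta};
  multi_tensor_sgd_cuda(/*chunk_size=*/2048 * 32, noop, lists,
                        /*wd=*/0.f, /*momentum=*/0.9f, /*dampening=*/0.f, lr,
                        /*nesterov=*/false, /*first_run=*/false,
                        /*wd_after_momentum=*/false, /*scale=*/1.f);
}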

View File

@@ -0,0 +1,70 @@
/* This code is from NVIDIA Megatron,
 * with minor changes. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
int attn_heads);
torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches,
attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::
get_batch_per_block,
"Return Batch per block size.");
}
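// Shape contract (summary inferred from the CUDA implementation further down
// in this PR): `input` is [batches, attn_heads, query_seq_len, key_seq_len];
// `mask` is [1 or batches, 1, query_seq_len, key_seq_len], where a value of 1
// marks a position to be masked (it is filled with -10000 before the softmax).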

View File

@@ -0,0 +1,538 @@
/* This code is from NVIDIA Megatron,
 * with minor changes. */
#pragma once
#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2 *)dst) = *((half2 *)src);
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional
* features 1) input scaling 2) Explicit masking
*/
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
int micro_batch_size, int element_count, int pad_batches) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch =
(blockDim.y *
(blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
threadIdx.y) *
WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch =
(blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i * element_count + it * WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
int micro_batch_size, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset =
first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] =
(acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] =
(output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
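// Worked example (illustrative): key_seq_len = 1024 gives log2_elements = 10
// and next_power_of_two = 1024, so warp_size = min(1024, C10_WARP_SIZE) = 32
// on CUDA, batches_per_warp = 1, warps_per_block = 128 / 32 = 4, and the
// function returns 4 batches per block.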
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len, int key_seq_len,
int batches, int attn_heads,
int pad_batches) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_forward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len, int key_seq_len,
int batches, int attn_heads) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}

View File

@@ -0,0 +1,89 @@
/* This code is from NVIDIA Megatron,
 * with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches,
int attn_heads) {
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results = torch::empty(
{batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(), "dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr), scale_factor,
query_seq_len, key_seq_len, batches, attn_heads, pad_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len,
// seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(), "dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, query_seq_len, key_seq_len, batches, attn_heads););
// backward pass is completely in-place
return output_grads;
}
} // namespace scaled_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn

View File

@@ -0,0 +1,54 @@
/* This code is from NVIDIA Megatron,
 * with minor changes. */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(torch::Tensor const& output_grads,
torch::Tensor const& softmax_results, float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}

View File

@@ -0,0 +1,600 @@
/* This code is from NVIDIA Megatron,
 * with minor changes. */
#pragma once
#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2 *)dst) = *((half2 *)src);
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst) {
*dst = 0.0;
}
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst) {
*((float2 *)dst) = make_float2(0.0f, 0.0f);
}
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) {
*dst = 0.0;
}
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
*((float2 *)dst) = make_float2(0.0f, 0.0f);
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional
* features 1) input scaling 2) Implicit time (diagonal masking)
*/
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size,
int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch =
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
blockIdx.x;
int local_seq = blockIdx.x + 1;
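  // Causal masking: blockIdx.x indexes the query position, so this row may
  // only attend to the first (blockIdx.x + 1) keys; later positions are
  // treated as -inf here and written out as zeros in the store loop below.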
int warp_iteration_limit =
(local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_data, src + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if ((element_index + element) < batch_element_count) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] =
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (it < warp_iteration_limit) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < local_seq) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count * stride + it * WARP_SIZE, out);
} else if (element_index < element_count) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(
dst + i * element_count * stride + it * WARP_SIZE);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t,
int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
int micro_batch_size, int stride, int element_count) {
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch =
(blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH +
blockIdx.x;
int local_seq = blockIdx.x + 1;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to be computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the
// batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : local_seq;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
temp_output, output + i * element_count * stride + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (element_index + element < batch_element_count) {
grad_reg[i][it + element] =
(acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches) break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] =
(output_t)(scale * (grad_reg[i][it + element] -
output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
gradInput + i * element_count * stride + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
output_t *dst, const input_t *src, const input_t scale,
int softmax_elements, int softmax_elements_stride, int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_forward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t,
acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
dst, src, scale, batch_count, softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
}
}
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
output_t *grad_input, input_t *grad, const input_t *output,
const acc_t scale, int softmax_elements, int softmax_elements_stride,
int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
if (softmax_elements == 0) {
return;
} else {
int log2_elements = log2_ceil(softmax_elements);
const int next_power_of_two = 1 << log2_elements;
int seq_len = softmax_elements;
int batch_count = attn_batches * seq_len;
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
(next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    // use 128 threads per block to maximize GPU utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t,
acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input, grad, output, scale, batch_count,
softmax_elements_stride, softmax_elements);
break;
default:
break;
}
}
}

View File

@@ -0,0 +1,75 @@
/* This code is from NVIDIA Megatron,
 * with minor changes. */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) {
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t,
float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr), scale_factor, seq_len,
seq_len, attn_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 3d tensor with dimensions [attn_batches, seq_len,
// seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t,
float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor, seq_len, seq_len, attn_batches););
// backward pass is completely in-place
return output_grads;
}
} // namespace scaled_upper_triang_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn

View File

@@ -0,0 +1,279 @@
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
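// Usage note (illustrative): the LEVEL argument only controls the name of the
// type alias the macro defines (scalar_t_0, scalar_t_1, ...), so two dispatches
// can be nested, as in the nested DISPATCH_FLOAT_AND_HALF calls inside
// multi_tensor_scale_cuda in this PR, where the body is instantiated for every
// (input, output) dtype combination.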
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::Half) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Half && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else { \
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
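// Usage note (illustrative): `x` must point to at least blockDim.x * blockDim.y
// elements of shared memory. The fully reduced block-wide sum ends up in
// threads with tid < lanes; when share_result is true it is also written back
// to x[0..lanes) and made visible to the whole block via __syncthreads().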
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
T *x, T val, int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final =
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}