[NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.cpp code style (#636)

Sze-qq 2022-04-02 13:28:57 +08:00 committed by binmakeswell
parent 6fcb381801
commit 10591ecdf9


@@ -20,446 +20,447 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <iostream>
#include <math.h>
#include <memory>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>
#include <type_traits>
#include <unordered_map>

static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

// C++ interface
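Side note on the registry just above: s_optimizers erases the optimizer's type behind std::shared_ptr<void>, so the Python-facing calls can pass a plain integer handle around. A minimal self-contained sketch of that pattern follows (the stand-in struct and main are illustrative, not this file's code):

// Illustrative, self-contained sketch -- the struct and main are stand-ins.
#include <iostream>
#include <memory>
#include <unordered_map>

// Stand-in for the real Adam_Optimizer declared in cpu_adam.h.
struct Adam_Optimizer {
  float alpha;
  explicit Adam_Optimizer(float a) : alpha(a) {}
};

// Same shape as this file's registry: the value type is erased, but
// shared_ptr<void> remembers the real deleter, so erase() destroys cleanly.
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;

int main() {
  s_optimizers[0] = std::make_shared<Adam_Optimizer>(1e-3f);
  // Recover the typed pointer on the way back in, as a step call would:
  auto opt = std::static_pointer_cast<Adam_Optimizer>(s_optimizers.at(0));
  std::cout << "alpha = " << opt->alpha << "\n";
  s_optimizers.erase(0); // runs ~Adam_Optimizer despite the void value type
  return 0;
}
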
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            float loss_scale) {
  size_t rounded_size = 0;

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size)
      copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      AVX_Data grad_4;
      if (grad_half_precision) {
        grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
      } else {
        grad_4.data = SIMD_LOAD(grads + i);
      }
      if (loss_scale > 0) {
        AVX_Data loss_scale_vec;
        loss_scale_vec.data = SIMD_SET(loss_scale);
        grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
      }
      AVX_Data momentum_4;
      momentum_4.data = SIMD_LOAD(_exp_avg + i);

      AVX_Data variance_4;
      variance_4.data = SIMD_LOAD(_exp_avg_sq + i);

      AVX_Data param_4;
      if (param_half_precision) {
        param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
      } else {
        param_4.data = SIMD_LOAD(_params + i);
      }

      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
      }
      momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
      momentum_4.data =
          SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
      variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
      grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
      variance_4.data =
          SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
      grad_4.data = SIMD_SQRT(variance_4.data);
      grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
      grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
      if (_weight_decay > 0 && _adamw_mode) {
        param_4.data =
            SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data);
      }
      param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

      if (param_half_precision) {
        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
      } else {
        SIMD_STORE(_params + i, param_4.data);
      }
      SIMD_STORE(_exp_avg + i, momentum_4.data);
      SIMD_STORE(_exp_avg_sq + i, variance_4.data);
    }
  }
#endif
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size)
        copy_size = _param_size - t;
      size_t offset = copy_size + t;

#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k];
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param =
            param_half_precision ? (float)params_cast_h[k] : _params[k];
        float momentum = _exp_avg[k];
        float variance = _exp_avg_sq[k];
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;
        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;
        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;
        if (param_half_precision)
          params_cast_h[k] = (__half)param;
        else
          _params[k] = param;
        _exp_avg[k] = momentum;
        _exp_avg_sq[k] = variance;
      }
    }
  }
}
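For reference, the scalar tail above spells out the whole update that the SIMD path computes several lanes at a time, including the loss-scale division that undoes mixed-precision gradient scaling. A self-contained sketch of one element's step under the usual Adam/AdamW definitions (the names and the reading of bias_correction2 as 1/sqrt(1 - beta2^t) are assumptions from context, not quoted from cpu_adam.h):

// Illustrative, self-contained sketch -- not part of cpu_adam.cpp.
#include <cmath>
#include <cstdio>

// One element's Adam/AdamW step, mirroring the scalar tail of Step_1.
// Assumption: bias_correction1 = 1 - beta1^t and bias_correction2 plays
// the role of 1/sqrt(1 - beta2^t), matching how the kernel uses them.
void adam_update_one(float &param, float grad, float &m, float &v,
                     float alpha, float beta1, float beta2, float eps,
                     float weight_decay, float bias_correction1,
                     float bias_correction2, bool adamw_mode) {
  if (weight_decay > 0 && !adamw_mode)
    grad += param * weight_decay; // classic Adam folds L2 decay into the grad
  m = beta1 * m + (1 - beta1) * grad;        // first moment
  v = beta2 * v + (1 - beta2) * grad * grad; // second moment
  float denom = std::sqrt(v) * bias_correction2 + eps;
  if (weight_decay > 0 && adamw_mode)
    param -= alpha * weight_decay * param;   // decoupled decay (AdamW)
  param -= (alpha / bias_correction1) * (m / denom);
}

int main() {
  float p = 1.0f, m = 0.0f, v = 0.0f;
  // step t = 1: 1 - 0.9^1 = 0.1 and 1/sqrt(1 - 0.999^1) ~= 31.6228
  adam_update_one(p, 0.1f, m, v, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.01f, 0.1f,
                  31.6228f, true);
  std::printf("p=%f m=%f v=%f\n", p, m, v);
  return 0;
}
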
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            float loss_scale) {
  size_t rounded_size = 0;

  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size)
      copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      AVX_Data grad_4[4];
      AVX_Data momentum_4[4];
      AVX_Data variance_4[4];
      AVX_Data param_4[4];
#pragma unroll 4
      for (int j = 0; j < 4; j++) {
        if (grad_half_precision) {
          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
        } else {
          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
        }
        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
        if (param_half_precision) {
          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
        } else {
          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
        }
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        if (param_half_precision) {
          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
                          param_4[j].data);
        } else {
          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
        }
        SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
        SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, loss_scale);
}
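The tail call above is the general pattern: each stepper rounds _param_size down to its own vector granularity and forwards the remainder to the next-narrower stepper, ending in Step_1's scalar loop. A small sketch of how the sizes partition, assuming ROUND_DOWN in cpu_adam.h is the usual power-of-two mask (the macro body is an assumption; only its use is visible here):

// Illustrative sketch -- the ROUND_DOWN body is assumed, not quoted.
#include <cstddef>
#include <cstdio>

#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) // power-of-two steps

int main() {
  const size_t SIMD_WIDTH = 8; // e.g. AVX2 float lanes
  size_t param_size = 1000;
  size_t r8 = ROUND_DOWN(param_size, SIMD_WIDTH * 8);           // Step_8 part
  size_t r4 = r8 + ROUND_DOWN(param_size - r8, SIMD_WIDTH * 4); // Step_4 part
  size_t r1 = r4 + ROUND_DOWN(param_size - r4, SIMD_WIDTH);     // Step_1 SIMD
  std::printf("Step_8: %zu | Step_4: %zu | Step_1 SIMD: %zu | scalar: %zu\n",
              r8, r4 - r8, r1 - r4, param_size - r1);
  // prints: Step_8: 960 | Step_4: 32 | Step_1 SIMD: 8 | scalar: 0
  return 0;
}
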
int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
                          float betta1 = 0.9, float betta2 = 0.999,
                          float eps = 1e-8, float weight_decay = 0,
                          bool adamw_mode = true, bool should_log = false) {
  auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps,
                                              weight_decay, adamw_mode);
  s_optimizers[optimizer_id] = opt;
  if (should_log) {
    std::string avx_type = "";
#if defined(__AVX512__)
    avx_type = "AVX512";
#else
#if defined(__AVX256__) or defined(__AVX2__)
    avx_type = "AVX2";
#else
    avx_type = "scalar";
#endif
#endif
    printf("Adam Optimizer #%d is created with %s arithmetic capability.\n",
           optimizer_id, avx_type.c_str());
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha, betta1, betta2, weight_decay, (int)adamw_mode);
  }
  return 0;
}

void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                            float *_exp_avg_sq, size_t _param_size,
                            bool param_half_precision, bool grad_half_precision,
                            float loss_scale) {
  size_t rounded_size = 0;
  __half *params_cast_h = NULL;
  __half *grads_cast_h = NULL;
  if (param_half_precision) {
    params_cast_h = reinterpret_cast<__half *>(_params);
  }
  if (grad_half_precision) {
    grads_cast_h = reinterpret_cast<__half *>(grads);
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  AVX_Data betta1_4;
  betta1_4.data = SIMD_SET(_betta1);
  AVX_Data betta2_4;
  betta2_4.data = SIMD_SET(_betta2);

  float betta1_minus1 = 1 - _betta1;
  AVX_Data betta1_minus1_4;
  betta1_minus1_4.data = SIMD_SET(betta1_minus1);
  float betta2_minus1 = 1 - _betta2;
  AVX_Data betta2_minus1_4;
  betta2_minus1_4.data = SIMD_SET(betta2_minus1);

  AVX_Data bias2_sqrt;
  bias2_sqrt.data = SIMD_SET(_bias_correction2);

  AVX_Data eps_4;
  eps_4.data = SIMD_SET(_eps);

  float step_size = -1 * _alpha / _bias_correction1;
  AVX_Data step_size_4;
  step_size_4.data = SIMD_SET(step_size);

  float w_decay = -1 * _alpha * _weight_decay;
  AVX_Data weight_decay_4;
  if (_weight_decay > 0)
    weight_decay_4.data =
        (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size)
      copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      AVX_Data grad_4[8];
      AVX_Data momentum_4[8];
      AVX_Data variance_4[8];
      AVX_Data param_4[8];
#pragma unroll 8
      for (int j = 0; j < 8; j++) {
        if (grad_half_precision) {
          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
        } else {
          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
        }
        if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
          loss_scale_vec.data = SIMD_SET(loss_scale);
          grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
        }
        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
        if (param_half_precision) {
          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
        } else {
          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
        }
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data);
        }
        momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
        momentum_4[j].data =
            SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
        variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
        grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
        variance_4[j].data =
            SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
        grad_4[j].data = SIMD_SQRT(variance_4[j].data);
        grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
        grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j].data =
              SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data);
        }
        param_4[j].data =
            SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
        if (param_half_precision) {
          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
                          param_4[j].data);
        } else {
          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
        }
        SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
        SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
      }
    }
  }
#endif
  if (_param_size > rounded_size)
    Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
                                 : _params + rounded_size),
           (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                : grads + rounded_size),
           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
           (_param_size - rounded_size), param_half_precision,
           grad_half_precision, loss_scale);
}
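All three steppers lean on the SIMD_* macros and AVX_Data from cpu_adam.h. On an AVX2 build these plausibly map to 8-lane __m256 intrinsics, with F16C handling the __half conversions. A hedged sketch of such mappings follows (illustrative assumptions, not the header's actual definitions; compile with -mavx2 -mfma -mf16c):

// Illustrative AVX2 mappings only -- an assumption about cpu_adam.h.
#include <immintrin.h>

#define SIMD_WIDTH 8 // floats per __m256 vector
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_STORE(x, v) _mm256_storeu_ps((x), (v))
#define SIMD_MUL(a, b) _mm256_mul_ps((a), (b))
#define SIMD_DIV(a, b) _mm256_div_ps((a), (b))
#define SIMD_SQRT(a) _mm256_sqrt_ps(a)
#define SIMD_FMA(a, b, c) _mm256_fmadd_ps((a), (b), (c)) // a * b + c
// F16C: widen/narrow 8 halves <-> 8 floats in one instruction.
#define SIMD_LOAD_HALF(x) \
  _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i *>(x)))
#define SIMD_STORE_HALF(x, v)                        \
  _mm_storeu_si128(reinterpret_cast<__m128i *>(x),   \
                   _mm256_cvtps_ph((v), _MM_FROUND_TO_NEAREST_INT))

int main() {
  float in[8] = {1, 2, 3, 4, 5, 6, 7, 8}, out[8];
  __m256 v = SIMD_FMA(SIMD_LOAD(in), SIMD_SET(2.0f), SIMD_SET(1.0f));
  SIMD_STORE(out, v); // out[i] == 2 * in[i] + 1
  return out[0] == 3.0f ? 0 : 1;
}
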
int adam_step(int optimizer_id,
@@ -501,17 +502,13 @@ int adam_step(int optimizer_id,
  return 0;
}

int destroy_adam_optimizer(int optimizer_id) {
  s_optimizers.erase(optimizer_id);
  return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam_update", &adam_step, "CPU Adam update (C++)");
  m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)");
  m.def("destroy_adam", &destroy_adam_optimizer, "CPU Adam destroy (C++)");
}
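One behavioral note on the bindings: pybind11 does not carry C++ default arguments over automatically, so as written create_adam must receive all eight arguments from the Python side. If exposing the defaults were ever wanted, py::arg would do it; a sketch of that variant (a suggestion, not part of this commit):

// Hypothetical alternative to the m.def above; exposes the C++ defaults.
m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)",
      py::arg("optimizer_id"), py::arg("alpha") = 1e-3,
      py::arg("betta1") = 0.9, py::arg("betta2") = 0.999,
      py::arg("eps") = 1e-8, py::arg("weight_decay") = 0,
      py::arg("adamw_mode") = true, py::arg("should_log") = false);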