[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test
Author: Hongxin Liu
Date: 2023-10-16 21:56:53 +08:00
Committed by: GitHub
Parent: 7768afbad0
Commit: 4f68b3f10c
8 changed files with 148 additions and 136 deletions

File: cpu_adam.cpp

@@ -35,23 +35,19 @@ SOFTWARE
 void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
 
   float betta1_minus1 = 1 - _betta1;
   float betta2_minus1 = 1 - _betta2;
   float step_size = -1 * _alpha / _bias_correction1;
   float w_decay = -1 * _alpha * _weight_decay;
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
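
Note on the two changes above: rounded_size is now computed once at the top (the old assignment inside the AVX-only block is removed in the next hunk), and all four __half casts are unconditional, which is safe because each aliased pointer is only dereferenced when its precision flag is true. ROUND_DOWN itself is not shown in this diff; the sketch below assumes the usual power-of-two mask idiom from DeepSpeed-style CPU Adam kernels.

#include <cstddef>

// Assumed definition (not part of this diff): mask off the low bits to
// round a count down to a multiple of a power-of-two step.
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))

int main() {
  size_t param_size = 1000;
  size_t simd_width = 16;  // AVX-512: 16 fp32 lanes per register
  size_t rounded = ROUND_DOWN(param_size, simd_width);  // 992
  // Elements [rounded, param_size) fall through to the scalar tail loop.
  return rounded == 992 ? 0 : 1;
}
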
@@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 #pragma omp parallel for
     for (size_t i = t; i < offset; i += SIMD_WIDTH) {
       AVX_Data grad_4;
-      if (grad_half_precision) {
-        grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i);
-      } else {
-        grad_4.data = SIMD_LOAD(grads + i);
-      }
+      this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4);
 
       if (loss_scale > 0) {
         AVX_Data loss_scale_vec;
         loss_scale_vec.data = SIMD_SET(loss_scale);
         grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data);
       }
 
       AVX_Data momentum_4;
-      momentum_4.data = SIMD_LOAD(_exp_avg + i);
+      this->simd_load(momentum_half_precision, _exp_avg + i,
+                      momentum_cast_h + i, momentum_4);
 
       AVX_Data variance_4;
-      variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
+      this->simd_load(variance_half_precision, _exp_avg_sq + i,
+                      variance_cast_h + i, variance_4);
 
       AVX_Data param_4;
-      if (param_half_precision) {
-        param_4.data = SIMD_LOAD_HALF(params_cast_h + i);
-      } else {
-        param_4.data = SIMD_LOAD(_params + i);
-      }
+      this->simd_load(param_half_precision, _params + i, params_cast_h + i,
+                      param_4);
 
       if (_weight_decay > 0 && !_adamw_mode) {
         grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data);
@@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
       }
 
       param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
-      if (param_half_precision) {
-        SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data);
-      } else {
-        SIMD_STORE(_params + i, param_4.data);
-      }
-      SIMD_STORE(_exp_avg + i, momentum_4.data);
-      SIMD_STORE(_exp_avg_sq + i, variance_4.data);
+      this->simd_store(param_half_precision, _params + i, params_cast_h + i,
+                       param_4);
+      this->simd_store(momentum_half_precision, _exp_avg + i,
+                       momentum_cast_h + i, momentum_4);
+      this->simd_store(variance_half_precision, _exp_avg_sq + i,
+                       variance_cast_h + i, variance_4);
     }
   }
 #endif
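
Routing the momentum and variance stores through simd_store is the functional fix in this hunk: the old code emitted unconditional fp32 stores for _exp_avg and _exp_avg_sq, and in a pure-fp16 setup an 8-lane fp32 store covers 32 bytes while 8 fp16 elements own only 16, trampling the next 8 elements. A plain-C++ sketch of that hazard (buffer names illustrative):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint16_t exp_avg[16] = {};  // imagine 16 fp16 momentum values, all zero
  float lanes[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // one fp32 SIMD register
  std::memcpy(exp_avg, lanes, sizeof(lanes));  // fp32-width store: 32 bytes
  printf("exp_avg[9] = %u (should still be 0)\n", exp_avg[9]);
  return 0;
}
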
@@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
         }
         float param =
             param_half_precision ? (float)params_cast_h[k] : _params[k];
-        float momentum = _exp_avg[k];
-        float variance = _exp_avg_sq[k];
+        float momentum =
+            momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k];
+        float variance = variance_half_precision ? (float)variance_cast_h[k]
+                                                 : _exp_avg_sq[k];
         if (_weight_decay > 0 && !_adamw_mode) {
           grad = param * _weight_decay + grad;
         }
@@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
           params_cast_h[k] = (__half)param;
         else
           _params[k] = param;
-        _exp_avg[k] = momentum;
-        _exp_avg_sq[k] = variance;
+        if (momentum_half_precision)
+          momentum_cast_h[k] = (__half)(momentum);
+        else
+          _exp_avg[k] = momentum;
+        if (variance_half_precision)
+          variance_cast_h[k] = (__half)(variance);
+        else
+          _exp_avg_sq[k] = variance;
       }
     }
   }
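
Taken together, the last two hunks make the scalar tail convert each buffer at its own precision on load and on store, while the arithmetic itself stays in fp32. A condensed single-element sketch of that fp32 core (class members inlined as parameters; loss scaling and weight decay omitted; bias_correction2 is assumed to be 1 / sqrt(1 - beta2^step) as in DeepSpeed-derived kernels):

#include <cmath>
#include <cstdio>

float adam_one(float param, float grad, float &m, float &v, float beta1,
               float beta2, float eps, float bias_correction2,
               float step_size /* -lr / (1 - beta1^step) */) {
  m = m * beta1 + grad * (1 - beta1);         // first moment
  v = v * beta2 + grad * grad * (1 - beta2);  // second moment
  float denom = std::sqrt(v) * bias_correction2 + eps;
  return param + step_size * (m / denom);     // step_size is negative
}

int main() {
  float m = 0.0f, v = 0.0f;
  // One step with bias correction disabled (both corrections == 1).
  printf("%f\n", adam_one(1.0f, 0.1f, m, v, 0.9f, 0.999f, 1e-8f, 1.0f, -1e-3f));
  return 0;
}
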
@@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
 
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
@@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[4];
 #pragma unroll 4
       for (int j = 0; j < 4; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
 
         if (loss_scale > 0) {
           AVX_Data loss_scale_vec;
           loss_scale_vec.data = SIMD_SET(loss_scale);
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
 
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
         }
 
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-        SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
            (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                 : grads + rounded_size),
-           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                    : _exp_avg + rounded_size),
+           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                    : _exp_avg_sq + rounded_size),
            (_param_size - rounded_size), param_half_precision,
-           grad_half_precision, loss_scale);
+           grad_half_precision, momentum_half_precision,
+           variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                             float *_exp_avg_sq, size_t _param_size,
                             bool param_half_precision, bool grad_half_precision,
-                            float loss_scale) {
-  size_t rounded_size = 0;
-  __half *params_cast_h = NULL;
-  __half *grads_cast_h = NULL;
-  if (param_half_precision) {
-    params_cast_h = reinterpret_cast<__half *>(_params);
-  }
-  if (grad_half_precision) {
-    grads_cast_h = reinterpret_cast<__half *>(grads);
-  }
+                            bool momentum_half_precision,
+                            bool variance_half_precision, float loss_scale) {
+  size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
+  __half *params_cast_h = reinterpret_cast<__half *>(_params);
+  __half *grads_cast_h = reinterpret_cast<__half *>(grads);
+  __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
+  __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
 
 #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
   AVX_Data betta1_4;
   betta1_4.data = SIMD_SET(_betta1);
@@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
   if (_weight_decay > 0)
     weight_decay_4.data =
         (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
-  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
@@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
       AVX_Data param_4[8];
 #pragma unroll 8
       for (int j = 0; j < 8; j++) {
-        if (grad_half_precision) {
-          grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j,
+                        grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]);
 
         if (loss_scale > 0) {
          AVX_Data loss_scale_vec;
           loss_scale_vec.data = SIMD_SET(loss_scale);
           grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data);
         }
 
-        momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
-        variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
-        if (param_half_precision) {
-          param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j);
-        } else {
-          param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
-        }
+        this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                        momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_load(variance_half_precision,
+                        _exp_avg_sq + i + SIMD_WIDTH * j,
+                        variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
+        this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j,
+                        params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
 
         if (_weight_decay > 0 && !_adamw_mode) {
           grad_4[j].data =
@@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
         param_4[j].data =
             SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
-        if (param_half_precision) {
-          SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j),
-                          param_4[j].data);
-        } else {
-          SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
-        }
-        SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data);
-        SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data);
+        this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j,
+                         params_cast_h + i + SIMD_WIDTH * j, param_4[j]);
+        this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j,
+                         momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]);
+        this->simd_store(variance_half_precision,
+                         _exp_avg_sq + i + SIMD_WIDTH * j,
+                         variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]);
       }
     }
   }
@@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                                  : _params + rounded_size),
            (grad_half_precision ? (float *)(grads_cast_h + rounded_size)
                                 : grads + rounded_size),
-           (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size),
+           (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
+                                    : _exp_avg + rounded_size),
+           (variance_half_precision ? (float *)(variance_cast_h + rounded_size)
+                                    : _exp_avg_sq + rounded_size),
            (_param_size - rounded_size), param_half_precision,
-           grad_half_precision, loss_scale);
+           grad_half_precision, momentum_half_precision,
+           variance_half_precision, loss_scale);
 }
 
 void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
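
The ternaries in this tail call matter: when a buffer is fp16, the rounded_size offset must be applied to the __half view before casting back to float *, since each remaining element is 2 bytes in that view rather than 4. A minimal demonstration, with uint16_t standing in for __half:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  float storage[8] = {};  // one raw allocation, viewed two ways
  uint16_t *h = reinterpret_cast<uint16_t *>(storage);  // fp16-style view
  float *f = storage;                                   // fp32 view
  std::size_t rounded_size = 4;  // elements already consumed by SIMD
  // The fp16 view advances 4 * 2 = 8 bytes; the fp32 view 4 * 4 = 16.
  printf("%td vs %td bytes\n", (char *)(h + rounded_size) - (char *)storage,
         (char *)(f + rounded_size) - (char *)storage);
  return 0;
}
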
@@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
   this->update_state(lr, epsilon, weight_decay, bias_correction);
   this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
                params_c.numel(), (params.options().dtype() == at::kHalf),
-               (grads.options().dtype() == at::kHalf), loss_scale);
+               (grads.options().dtype() == at::kHalf),
+               (exp_avg.options().dtype() == at::kHalf),
+               (exp_avg_sq.options().dtype() == at::kHalf), loss_scale);
 }
 
 namespace py = pybind11;
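
With the flags derived per tensor, any mix of fp32/fp16 across the four buffers is accepted; the pure-fp16 Gemini case simply passes everything as at::kHalf. A hypothetical driver for that case (tensor shapes and hyperparameters are illustrative, not from this diff):

#include <torch/torch.h>

int main() {
  auto opts = torch::dtype(torch::kHalf);
  auto params = torch::randn({1024}, opts);
  auto grads = torch::randn({1024}, opts);
  auto exp_avg = torch::zeros({1024}, opts);
  auto exp_avg_sq = torch::zeros({1024}, opts);
  // All four dtype checks in the binding above are true here, so Step_8
  // would run with every *_half_precision flag set, e.g.:
  // adam.step(1, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f, true,
  //           params, grads, exp_avg, exp_avg_sq, -1.0f);
  return 0;
}
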

File: cpu_adam.h

@@ -50,9 +50,9 @@ SOFTWARE
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
-      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
+                                     d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
 #define SIMD_WIDTH 8
@@ -66,9 +66,9 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
-      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
+#define SIMD_STORE_HALF(x, d) \
+  _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
+                                  d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
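
Both SIMD_STORE_HALF fixes swap an aligned store (_mm256_store_ps / _mm_store_ps) for its unaligned variant and move the (float *) cast into the macro. The aligned forms require a 32- or 16-byte-aligned destination, which an fp16 buffer addressed at an arbitrary element offset cannot guarantee. A minimal AVX2/F16C sketch of the fixed path (compile with -mavx2 -mf16c; illustrative only):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  float in[8] = {0.5f, 1, 2, 3, 4, 5, 6, 7};
  uint16_t out[9] = {};  // fp16 destination buffer
  __m128i h = _mm256_cvtps_ph(_mm256_loadu_ps(in), _MM_FROUND_TO_NEAREST_INT);
  // out + 1 is only 2-byte aligned: _mm_store_ps here would be undefined
  // behavior (and typically faults), while _mm_storeu_ps has no alignment
  // requirement and writes the same 16 bytes.
  _mm_storeu_ps((float *)(out + 1), _mm_castsi128_ps(h));
  printf("fp16 bits of 0.5f: 0x%04x\n", out[1]);  // 0x3800
  return 0;
}
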
@@ -83,11 +83,12 @@ union AVX_Data {
 #endif
 
-#define STEP(SPAN) \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size, \
-                   bool param_half_precision = false, \
-                   bool grad_half_precision = false, float loss_scale = -1);
+#define STEP(SPAN) \
+  void Step_##SPAN( \
+      float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
+      size_t _param_size, bool param_half_precision = false, \
+      bool grad_half_precision = false, bool momentum_half_precision = false, \
+      bool variance_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
  public:
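
For reference, STEP(1) under the new macro expands to the declaration below. The two added flags default to false, so pre-existing call sites that pass only the param/grad flags keep compiling unchanged.

// Expansion of STEP(1) after this change (whitespace adjusted):
void Step_1(
    float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq,
    size_t _param_size, bool param_half_precision = false,
    bool grad_half_precision = false, bool momentum_half_precision = false,
    bool variance_half_precision = false, float loss_scale = -1);
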
@@ -141,6 +142,24 @@ class Adam_Optimizer {
     }
   }
 
+  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
+                        AVX_Data &data) {
+    if (is_half) {
+      data.data = SIMD_LOAD_HALF(h_ptr);
+    } else {
+      data.data = SIMD_LOAD(ptr);
+    }
+  }
+
+  inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
+                         AVX_Data &data) {
+    if (is_half) {
+      SIMD_STORE_HALF(h_ptr, data.data);
+    } else {
+      SIMD_STORE(ptr, data.data);
+    }
+  }
+
   void step(size_t step, float lr, float beta1, float beta2, float epsilon,
             float weight_decay, bool bias_correction, torch::Tensor &params,
             torch::Tensor &grads, torch::Tensor &exp_avg,
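
These two inline helpers replace every open-coded precision branch in the .cpp hunks above. A standalone AVX2 sketch of the same dispatch, with uint16_t standing in for __half's storage and _mm_storeu_si128 as the byte-equivalent of the header's SIMD_STORE_HALF (compile with -mavx2 -mf16c):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

union AVX_Data { __m256 data; };

inline void simd_load(bool is_half, float *ptr, uint16_t *h_ptr,
                      AVX_Data &data) {
  if (is_half)
    data.data = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)h_ptr));
  else
    data.data = _mm256_loadu_ps(ptr);
}

inline void simd_store(bool is_half, float *ptr, uint16_t *h_ptr,
                       AVX_Data &data) {
  if (is_half)
    _mm_storeu_si128((__m128i *)h_ptr,
                     _mm256_cvtps_ph(data.data, _MM_FROUND_TO_NEAREST_INT));
  else
    _mm256_storeu_ps(ptr, data.data);
}

int main() {
  float f32[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  uint16_t f16[8];
  AVX_Data v;
  simd_load(false, f32, nullptr, v);  // read 8 fp32 values
  simd_store(true, nullptr, f16, v);  // write them back as fp16
  simd_load(true, nullptr, f16, v);   // reload, converting fp16 -> fp32
  float back[8];
  _mm256_storeu_ps(back, v.data);
  printf("%g\n", back[7]);  // 8: small integers are exact in fp16
  return 0;
}
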