[NFC] polish colossalai/kernel/cuda_native/csrc/cpu_adam.h code style (#945)

This commit is contained in:
bajiaoyu517 2022-05-13 17:42:50 +08:00 committed by binmakeswell
parent 8ffdc38376
commit eb9a81d72a

View File

@ -48,10 +48,10 @@ SOFTWARE
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x) #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y) #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \ #define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \ #define SIMD_STORE_HALF(x, d) \
_mm256_store_ps( \ _mm256_store_ps( \
x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__) #elif defined(__AVX256__) or defined(__AVX2__)
@ -66,8 +66,8 @@ SOFTWARE
#define SIMD_SQRT(x) _mm256_sqrt_ps(x) #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y) #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \ #define SIMD_STORE_HALF(x, d) \
_mm_store_ps( \ _mm_store_ps( \
x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
#endif #endif
@ -83,19 +83,25 @@ union AVX_Data {
#endif #endif
#define STEP(SPAN) \ #define STEP(SPAN) \
void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
float *_exp_avg_sq, size_t _param_size, \ float *_exp_avg_sq, size_t _param_size, \
bool param_half_precision = false, \ bool param_half_precision = false, \
bool grad_half_precision = false, float loss_scale = -1); bool grad_half_precision = false, float loss_scale = -1);
class Adam_Optimizer { class Adam_Optimizer {
public: public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0, float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true) bool adamw_mode = true)
: _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps), : _alpha(alpha),
_weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0), _betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {} _adamw_mode(adamw_mode) {}
~Adam_Optimizer() {} ~Adam_Optimizer() {}
@ -135,7 +141,7 @@ public:
} }
} }
private: private:
float _alpha; float _alpha;
float _betta1; float _betta1;
float _betta2; float _betta2;