From b73a048ad8cf549c5ce3cecb72820923754fb520 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Fri, 4 Mar 2022 16:05:15 +0800 Subject: [PATCH] [zero] cpu adam kernel (#288) * Added CPU Adam * finished the cpu adam * updated the license * delete useless parameters, removed resnet * modified the method off cpu adam unittest * deleted some useless codes * removed useless codes Co-authored-by: ver217 Co-authored-by: Frank Lee Co-authored-by: jiaruifang --- .../kernel/cuda_native/csrc/cpu_adam.cpp | 517 ++++++++++++++++++ colossalai/kernel/cuda_native/csrc/cpu_adam.h | 163 ++++++ colossalai/nn/optimizer/__init__.py | 3 +- colossalai/nn/optimizer/cpu_adam.py | 103 ++++ .../zero/sharded_optim/sharded_optim.py | 14 +- setup.py | 10 +- tests/test_optimizer/unittest_cpu_adam.py | 197 +++++++ 7 files changed, 1001 insertions(+), 6 deletions(-) create mode 100644 colossalai/kernel/cuda_native/csrc/cpu_adam.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/cpu_adam.h create mode 100644 colossalai/nn/optimizer/cpu_adam.py create mode 100644 tests/test_optimizer/unittest_cpu_adam.py diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp new file mode 100644 index 000000000..efd569fcd --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -0,0 +1,517 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#include "cpu_adam.h" +#include +#include +#include +#include +#include +#include +#include +#include + + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adam_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + bool param_half_precision, + bool grad_half_precision, + float loss_scale) +{ + size_t rounded_size = 0; + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + + __half* params_cast_h = NULL; + __half* grads_cast_h = NULL; + + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half*>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data grad_4; + if (grad_half_precision) { + grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i); + } else { + grad_4.data = SIMD_LOAD(grads + i); + } + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); + } + AVX_Data momentum_4; + momentum_4.data = SIMD_LOAD(_exp_avg + i); + + AVX_Data variance_4; + variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + + AVX_Data param_4; + if (param_half_precision) { + param_4.data = SIMD_LOAD_HALF(params_cast_h + i); + } else { + param_4.data = SIMD_LOAD(_params + i); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); + } + momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data); + momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data); + variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data); + grad_4.data = SIMD_MUL(grad_4.data, grad_4.data); + variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data); + grad_4.data = SIMD_SQRT(variance_4.data); + grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data); + grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, param_4.data); + } + param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); + + if (param_half_precision) { + 
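+                // params_cast_h aliases the caller's fp16 buffer, so the updated fp32 SIMD lanes are converted back to half precision on store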
SIMD_STORE_HALF((float*)(params_cast_h + i), param_4.data); + } else { + SIMD_STORE(_params + i, param_4.data); + } + SIMD_STORE(_exp_avg + i, momentum_4.data); + SIMD_STORE(_exp_avg_sq + i, variance_4.data); + } + } +#endif + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = grad_half_precision ? (float)grads_cast_h[k] : grads[k]; + if (loss_scale > 0) { grad /= loss_scale; } + float param = param_half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; + + if (param_half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + } + } +} + +void Adam_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + bool param_half_precision, + bool grad_half_precision, + float loss_scale) +{ + size_t rounded_size = 0; + + __half* params_cast_h = NULL; + __half* grads_cast_h = NULL; + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half*>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + } + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = (_adamw_mode ? 
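+    // in AdamW mode the decay factor already carries the learning rate (w_decay = -alpha * weight_decay), so the FMA in the loop applies decoupled decay directly to the parameter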
SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { + AVX_Data grad_4[4]; + AVX_Data momentum_4[4]; + AVX_Data variance_4[4]; + AVX_Data param_4[4]; +#pragma unroll 4 + for (int j = 0; j < 4; j++) { + if (grad_half_precision) { + grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); + } else { + grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); + } + + if(loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + + if (param_half_precision) { + param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); + } else { + param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + if (param_half_precision) { + SIMD_STORE_HALF((float*)(params_cast_h + i + SIMD_WIDTH * j), param_4[j].data); + } else { + SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); + } + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + } + } + } +#endif + if (_param_size > rounded_size) + Step_1((param_half_precision ? (float*)(params_cast_h + rounded_size) : _params + rounded_size), + (grad_half_precision ? 
(float*)(grads_cast_h + rounded_size) : grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + param_half_precision, + grad_half_precision, + loss_scale); +} + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log){ + + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) or defined(__AVX2__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + bool param_half_precision, + bool grad_half_precision, + float loss_scale) +{ + size_t rounded_size = 0; + __half* params_cast_h = NULL; + __half* grads_cast_h = NULL; + if (param_half_precision) { + params_cast_h = reinterpret_cast<__half*>(_params); + } + if (grad_half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + } +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + float betta2_minus1 = 1 - _betta2; + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay_4; + if (_weight_decay > 0) + weight_decay_4.data = (_adamw_mode ? 
SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { + AVX_Data grad_4[8]; + AVX_Data momentum_4[8]; + AVX_Data variance_4[8]; + AVX_Data param_4[8]; +#pragma unroll 8 + for (int j = 0; j < 8; j++) { + if (grad_half_precision) { + grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); + } else { + grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); + } + + if (loss_scale > 0) { + AVX_Data loss_scale_vec; + loss_scale_vec.data = SIMD_SET(loss_scale); + grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); + } + + momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); + variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); + + if (param_half_precision) { + param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); + } else { + param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); + } + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, grad_4[j].data); + } + momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data); + momentum_4[j].data = SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data); + variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data); + grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data); + variance_4[j].data = SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data); + grad_4[j].data = SIMD_SQRT(variance_4[j].data); + grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data); + grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j].data = SIMD_FMA(param_4[j].data, weight_decay_4.data, param_4[j].data); + } + param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); + + if (param_half_precision) { + SIMD_STORE_HALF((float*)(params_cast_h + i + SIMD_WIDTH * j), param_4[j].data); + } else { + SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); + } + + SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); + SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + } + } + } +#endif + if (_param_size > rounded_size) + Step_4((param_half_precision ? (float*)(params_cast_h + rounded_size) : _params + rounded_size), + (grad_half_precision ? 
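+    // any tail shorter than SIMD_WIDTH * 8 is handed to Step_4, which in turn delegates its own remainder to Step_1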
(float*)(grads_cast_h + rounded_size) : grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + param_half_precision, + grad_half_precision, + loss_scale); +} + +int adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + float loss_scale) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.size(0), + (params.options().dtype() == at::kHalf), + (grads.options().dtype() == at::kHalf), + loss_scale); + + return 0; +} + + + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adam_update", &adam_step, "CPU Adam update (C++)"); + m.def("create_adam", &create_adam_optimizer, "CPU Adam (C++)"); + m.def("destroy_adam", &destroy_adam_optimizer, "CPU Adam destroy (C++)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h new file mode 100644 index 000000000..758722a8f --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -0,0 +1,163 @@ +/* +Copyright (c) Microsoft Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE +*/ +#pragma once + +#include +#include +#include +#include +#include + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define TILE (128 * 1024 * 1024) + +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) + +#if defined(__AVX512__) +#define SIMD_WIDTH 16 +#define INTV __m256i +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_LOAD_HALF(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x))) +#define SIMD_STORE_HALF(x, d) _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) + +#elif defined(__AVX256__) or defined(__AVX2__) +#define SIMD_WIDTH 8 +#define INTV __m128i +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x))) +#define SIMD_STORE_HALF(x, d) _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) + +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) or defined(__AVX2__) + __m256 data; +#endif + // float data_f[16]; +}; + +#endif + + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + bool param_half_precision = false, \ + bool grad_half_precision = false, \ + float loss_scale = -1); + +class Adam_Optimizer { +public: + Adam_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode){} + ~Adam_Optimizer(){} + + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) + { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + 
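+    // weight decay coefficient, interpreted as decoupled AdamW decay or classic L2 depending on _adamw_mode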
+    float _weight_decay;
+
+    float _betta1_t;
+    float _betta2_t;
+    size_t _step;
+
+    float _bias_correction1;
+    float _bias_correction2;
+
+    bool _adamw_mode;
+};
diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py
index c084c5c86..afd14fab2 100644
--- a/colossalai/nn/optimizer/__init__.py
+++ b/colossalai/nn/optimizer/__init__.py
@@ -4,7 +4,8 @@ from .fused_lamb import FusedLAMB
 from .fused_sgd import FusedSGD
 from .lamb import Lamb
 from .lars import Lars
+from .cpu_adam import CPUAdam
 
 __all__ = [
-    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars'
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam'
 ]
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
new file mode 100644
index 000000000..21f607c48
--- /dev/null
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -0,0 +1,103 @@
+# modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/adam/cpu_adam.py
+
+import math
+import torch
+import time
+from pathlib import Path
+import colossalai
+
+
+class CPUAdam(torch.optim.Optimizer):
+    optimizer_id = 0
+
+    def __init__(self,
+                 model_params,
+                 lr=1e-3,
+                 bias_correction=True,
+                 betas=(0.9, 0.999),
+                 eps=1e-8,
+                 weight_decay=0,
+                 adamw_mode=True,
+                 loss_scale=-1,
+                 simd_log=False):
+
+        default_args = dict(lr=lr,
+                            betas=betas,
+                            eps=eps,
+                            weight_decay=weight_decay,
+                            bias_correction=bias_correction)
+        super(CPUAdam, self).__init__(model_params, default_args)
+        self.opt_id = CPUAdam.optimizer_id
+        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
+        self.adam_w_mode = adamw_mode
+        self.loss_scale = loss_scale
+        try:
+            import cpu_adam
+        except ImportError:
+            raise ImportError('Please install colossalai from source code to use CPUAdam')
+        self.cpu_adam_op = cpu_adam
+        self.cpu_adam_op.create_adam(self.opt_id,
+                                     lr,
+                                     betas[0],
+                                     betas[1],
+                                     eps,
+                                     weight_decay,
+                                     adamw_mode,
+                                     simd_log)
+
+    def __del__(self):
+        self.cpu_adam_op.destroy_adam(self.opt_id)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        # intended device for step
+        device = torch.device('cpu')
+
+        for group_id, group in enumerate(self.param_groups):
+            for param_id, p in enumerate(group['params']):
+
+                if p.grad is None:
+                    continue
+
+                assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
+                    "sure the cpu_offload is True"
+
+                state = self.state[p]
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+
+                    # gradient momentums
+                    state['exp_avg'] = torch.zeros_like(p.data,
+                                                        dtype=torch.float,
+                                                        device=device)
+                    # gradient variances
+                    state['exp_avg_sq'] = torch.zeros_like(p.data,
+                                                           dtype=torch.float,
+                                                           device=device)
+                    # memory_format=torch.preserve_format)
+
+                state['step'] += 1
+                beta1, beta2 = group['betas']
+
+                self.cpu_adam_op.adam_update(self.opt_id,
+                                             state['step'],
+                                             group['lr'],
+                                             beta1,
+                                             beta2,
+                                             group['eps'],
+                                             group['weight_decay'],
+                                             group['bias_correction'],
+                                             p.data,
+                                             p.grad.data,
+                                             state['exp_avg'],
+                                             state['exp_avg_sq'],
+                                             self.loss_scale)
+        return loss
diff --git a/colossalai/zero/sharded_optim/sharded_optim.py b/colossalai/zero/sharded_optim/sharded_optim.py
index 2be7a2808..9dff355db 100644
--- a/colossalai/zero/sharded_optim/sharded_optim.py
+++ b/colossalai/zero/sharded_optim/sharded_optim.py
@@ -45,7 +45,9 @@ class ShardedOptimizer(ColossalaiOptimizer):
                  mp_parallel_mode=ParallelMode.MODEL,
 
                  # cpu offload
-                 cpu_offload=False):
+
cpu_offload=False, + cpu_fp16_param=False, + cpu_fp16_grad=False): # TODO: add support for # 1. fp16 master weights @@ -63,6 +65,8 @@ class ShardedOptimizer(ColossalaiOptimizer): # cpu_offload self._cpu_offload = cpu_offload + self._cpu_fp16_param = cpu_fp16_param + self._cpu_fp16_grad = cpu_fp16_grad # get process groups self._dp_parallel_mode = dp_parallel_mode @@ -146,7 +150,11 @@ class ShardedOptimizer(ColossalaiOptimizer): # create a copy of fp32 weights of the parameters for which this rank is responsible fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id) - fp32_flat_current_rank = fp16_flat_current_rank.clone().float().detach() + # when using cpu offload, our cpu adam support fp16 paramters + if self._cpu_fp16_param: + fp32_flat_current_rank = fp16_flat_current_rank.detach() + else: + fp32_flat_current_rank = fp16_flat_current_rank.detach().float() device = 'cpu' if self._cpu_offload else get_current_device() fp32_flat_current_rank = fp32_flat_current_rank.to(device) fp32_flat_current_rank.requires_grad = True @@ -209,7 +217,7 @@ class ShardedOptimizer(ColossalaiOptimizer): fp32_partition_grad = torch.zeros_like(fp32_partition_param) fp32_partition_param.grad = fp32_partition_grad - # update the parameter with zero gradients for initialization of optimizer states + # update the parameter with zero gradients for initialization of optimizer stateus self._optimizer.step() # remove the grad of the paramter to save memory diff --git a/setup.py b/setup.py index 8b25229f7..292f82bd0 100644 --- a/setup.py +++ b/setup.py @@ -124,12 +124,12 @@ if build_cuda_ext: # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] - def cuda_ext_helper(name, sources, extra_cuda_flags): + def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): return CUDAExtension(name=name, sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources], include_dirs=[os.path.join( this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, + extra_compile_args={'cxx': ['-O3'] + version_dependent_macros + extra_cxx_flags, 'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)}) @@ -188,6 +188,12 @@ if build_cuda_ext: 'kernels/general_kernels.cu', 'kernels/cuda_util.cu'], extra_cuda_flags + cc_flag)) + + extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] + ext_modules.append(cuda_ext_helper('cpu_adam', + ['cpu_adam.cpp'], + extra_cuda_flags, + extra_cxx_flags)) setup( name='colossalai', diff --git a/tests/test_optimizer/unittest_cpu_adam.py b/tests/test_optimizer/unittest_cpu_adam.py new file mode 100644 index 000000000..2f1e62174 --- /dev/null +++ b/tests/test_optimizer/unittest_cpu_adam.py @@ -0,0 +1,197 @@ +# BSD 3-Clause License +# +# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, +# are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the psutil authors nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import math +import torch +import colossalai +try: + import cpu_adam +except ImportError: + raise ImportError("import cpu_adam error") + +def torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + bias_correction, + param, + grad, + exp_avg, + exp_avg_sq, + loss_scale, + use_adamw, +): + if loss_scale > 0: + grad.div_(loss_scale) + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + if weight_decay != 0: + if use_adamw: + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + else: + grad = grad.add(param, alpha=weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +class Test(): + def __init__(self): + self.opt_id = 0 + + def assertLess(self, data_diff, threshold, msg): + assert data_diff < threshold, msg + + def assertTrue(self, condition, msg): + assert condition, msg + + def check_res( + self, + step, + lr, + eps, + beta1, + beta2, + + weight_decay, + shape, + grad_dtype, + loss_scale, + use_adamw, + cpu_adam_op, + ): + p_data = torch.rand(shape, dtype=grad_dtype) + p_data_copy = p_data.clone().float() + p_grad = torch.rand(shape, dtype=grad_dtype) + if loss_scale > 0: + p_grad.mul_(loss_scale) + p_grad_copy = p_grad.clone().float() + exp_avg = torch.rand(shape) + exp_avg_copy = exp_avg.clone() + exp_avg_sq = torch.rand(shape) + exp_avg_sq_copy = exp_avg_sq.clone() + + cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True) + cpu_adam_op.adam_update( + self.opt_id, + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + p_data.view(-1), # fp32 data + p_grad.view(-1), # fp32 grad + exp_avg.view(-1), + exp_avg_sq.view(-1), + loss_scale, + ) + + torch_adam_update( + step, + lr, + beta1, + beta2, + eps, + weight_decay, + True, + p_data_copy, # fp32 data + p_grad_copy, # fp32 grad + exp_avg_copy, + exp_avg_sq_copy, + loss_scale, + use_adamw, + ) + + if loss_scale > 0: + p_grad.div_(loss_scale) + + var = p_data_copy - p_data + data_diff = torch.max(torch.abs(var)) + threshold = 2e-3 if grad_dtype else 1e-4 + 
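+        # note: `grad_dtype` is a torch.dtype object and is always truthy, so the looser 2e-3 tolerance is applied to both the fp16 and fp32 runs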
self.assertLess( + data_diff, + threshold, + f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps " + f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}", + ) + max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad)) + self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}") + max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) + self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") + max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) + self.assertTrue( + max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}" + ) + + def test_cpu_adam(self): + lr = 0.9 + eps = 1e-6 + weight_decay = 0 + for use_adamw in [False, True]: + for shape in [(1023, ), (32, 1024)]: + for step in range(1, 2): + for lr in [0.01]: + for eps in [1e-8]: + for beta1 in [0.9]: + for beta2 in [0.999]: + for weight_decay in [0.001]: + for grad_dtype in [torch.half, torch.float]: + for loss_scale in [-1, 2 ** 5]: + self.check_res( + step, + lr, + eps, + beta1, + beta2, + weight_decay, + shape, + grad_dtype, + loss_scale, + use_adamw, + cpu_adam, + ) + + +if __name__ == "__main__": + test = Test() + test.test_cpu_adam() + print('All is well.')
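
Usage sketch (illustrative, not part of the diff above): once the `cpu_adam` extension is built by the new setup.py target, the optimizer added in colossalai/nn/optimizer/cpu_adam.py can be driven as below. The flat parameter, its length, and the hyper-parameter values are assumptions made for this example only.

    import torch
    from colossalai.nn.optimizer import CPUAdam

    # the C++ kernel sizes its update by tensor.size(0), so it is fed a flat 1-D
    # parameter, which is what the ZeRO sharded optimizer hands over after
    # flattening each rank's shard onto the CPU
    flat_param = torch.nn.Parameter(torch.randn(1024))
    optimizer = CPUAdam([flat_param], lr=1e-3, weight_decay=1e-2, adamw_mode=True)

    flat_param.grad = torch.randn_like(flat_param)   # stand-in for a real backward pass
    optimizer.step()        # dispatches to the fused Step_8 / Step_4 / Step_1 kernels
    optimizer.zero_grad()

Keeping the parameters flat also lets the kernel stream contiguously through the parameters, gradients and both moment buffers with OpenMP and AVX.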