Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-28 04:55:25 +00:00)
[npu] add npu support for gemini and zero (#5067)
* [npu] setup device utils (#5047)
* [npu] add npu device support
* [npu] support low level zero
* [test] update npu zero plugin test
* [hotfix] fix import
* [test] recover tests
* [npu] gemini support npu (#5052)
* [npu] refactor device utils
* [gemini] support npu
* [example] llama2+gemini support npu
* [kernel] add arm cpu adam kernel (#5065)
* [kernel] add arm cpu adam
* [optim] update adam optimizer
* [kernel] arm cpu adam remove bf16 support
@@ -142,6 +142,7 @@ class Adam_Optimizer {
    }
  }

#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
                        AVX_Data &data) {
    if (is_half) {
@@ -159,6 +160,7 @@ class Adam_Optimizer {
      SIMD_STORE(ptr, data.data);
    }
  }
#endif

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp (new file, 304 lines)
@@ -0,0 +1,304 @@
#include "cpu_adam_arm.h"

void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  size_t rounded_size = 0;
#if defined(__aarch64__)
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

#if defined(__aarch64__)
  float32x4_t betta1_4 = simd_set(_betta1);
  float32x4_t betta2_4 = simd_set(_betta2);
  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
  float32x4_t eps_4 = simd_set(_eps);
  float32x4_t step_size_4 = simd_set(step_size);
  float32x4_t weight_decay_4;
  if (_weight_decay > 0) {
    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
  }
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH) {
      float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
      if (loss_scale > 0) {
        float32x4_t loss_scale_vec = simd_set(loss_scale);
        grad_4 = vdivq_f32(grad_4, loss_scale_vec);
      }
      float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
      float32x4_t variance_4 =
          simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
      float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
      if (_weight_decay > 0 && !_adamw_mode) {
        grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
      }
      momentum_4 = vmulq_f32(momentum_4, betta1_4);
      momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
      variance_4 = vmulq_f32(variance_4, betta2_4);
      grad_4 = vmulq_f32(grad_4, grad_4);
      variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
      grad_4 = vsqrtq_f32(variance_4);
      grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
      grad_4 = vdivq_f32(momentum_4, grad_4);
      if (_weight_decay > 0 && _adamw_mode) {
        param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
      }
      param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
      simd_store_offset(_params, param_dtype, param_4, i);
      simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
      simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
    }
  }
#endif
  if (_param_size > rounded_size) {
    for (size_t t = rounded_size; t < _param_size; t += TILE) {
      size_t copy_size = TILE;
      if ((t + TILE) > _param_size) copy_size = _param_size - t;
      size_t offset = copy_size + t;

#pragma omp parallel for
      for (size_t k = t; k < offset; k++) {
        float grad = scalar_load_offset(grads, grad_dtype, k);
        if (loss_scale > 0) {
          grad /= loss_scale;
        }
        float param = scalar_load_offset(_params, param_dtype, k);
        float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
        float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
        if (_weight_decay > 0 && !_adamw_mode) {
          grad = param * _weight_decay + grad;
        }
        momentum = momentum * _betta1;
        momentum = grad * betta1_minus1 + momentum;

        variance = variance * _betta2;
        grad = grad * grad;
        variance = grad * betta2_minus1 + variance;

        grad = sqrt(variance);
        grad = grad * _bias_correction2 + _eps;
        grad = momentum / grad;
        if (_weight_decay > 0 && _adamw_mode) {
          param += w_decay * param;
        }
        param = grad * step_size + param;

        scalar_store_offset(_params, param_dtype, param, k);
        scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
        scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
      }
    }
  }
}
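For reference, the per-element update that both the NEON path and the scalar tail loop in Step_1 implement can be read off the scalar branch. In the kernel's own terms (step_size = -alpha / bias_correction1, and a denominator scaled by bias_correction2 = 1 / sqrt(1 - beta2^t)), the update is, in LaTeX:

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2
\]
\[
p_t = p_{t-1} - \frac{\alpha}{1-\beta_1^t} \cdot
      \frac{m_t}{\sqrt{v_t}\,\big/\sqrt{1-\beta_2^t} + \epsilon}
\]

When _adamw_mode is false, weight decay is folded into the gradient first (g_t += lambda * p_{t-1}); when it is true, the decoupled AdamW decay p <- p - alpha * lambda * p is applied to the parameter before the update above.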
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  size_t rounded_size = 0;
#if defined(__aarch64__)
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;

#if defined(__aarch64__)
  float32x4_t betta1_4 = simd_set(_betta1);
  float32x4_t betta2_4 = simd_set(_betta2);
  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
  float32x4_t eps_4 = simd_set(_eps);
  float32x4_t step_size_4 = simd_set(step_size);
  float32x4_t weight_decay_4;
  if (_weight_decay > 0) {
    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
  }

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
      float32x4_t grad_4[4];
      float32x4_t momentum_4[4];
      float32x4_t variance_4[4];
      float32x4_t param_4[4];
#pragma unroll 4
      for (int j = 0; j < 4; j++) {
        grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
        if (loss_scale > 0) {
          float32x4_t loss_scale_vec = simd_set(loss_scale);
          grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
        }
        momentum_4[j] =
            simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
        variance_4[j] =
            simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
        param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
        }
        momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
        momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
        variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
        grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
        variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
        grad_4[j] = vsqrtq_f32(variance_4[j]);
        grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
        grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
        }
        param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
        simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
                          i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
                          i + SIMD_WIDTH * j);
      }
    }
  }
#endif
  if (_param_size > rounded_size) {
    Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
           scalar_seek_offset(grads, grad_dtype, rounded_size),
           scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
           scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
           (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
           exp_avg_sq_dtype, loss_scale);
  }
}
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
                           void *_exp_avg_sq, size_t _param_size,
                           at::ScalarType param_dtype,
                           at::ScalarType grad_dtype,
                           at::ScalarType exp_avg_dtype,
                           at::ScalarType exp_avg_sq_dtype, float loss_scale) {
  size_t rounded_size = 0;
#if defined(__aarch64__)
  rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif

  float betta1_minus1 = 1 - _betta1;
  float betta2_minus1 = 1 - _betta2;
  float step_size = -1 * _alpha / _bias_correction1;
  float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
  float32x4_t betta1_4 = simd_set(_betta1);
  float32x4_t betta2_4 = simd_set(_betta2);
  float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
  float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
  float32x4_t bias2_sqrt = simd_set(_bias_correction2);
  float32x4_t eps_4 = simd_set(_eps);
  float32x4_t step_size_4 = simd_set(step_size);
  float32x4_t weight_decay_4;
  if (_weight_decay > 0) {
    weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
  }

  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;

#pragma omp parallel for
    for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
      float32x4_t grad_4[8];
      float32x4_t momentum_4[8];
      float32x4_t variance_4[8];
      float32x4_t param_4[8];
#pragma unroll 4
      for (int j = 0; j < 8; j++) {
        grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
        if (loss_scale > 0) {
          float32x4_t loss_scale_vec = simd_set(loss_scale);
          grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
        }
        momentum_4[j] =
            simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
        variance_4[j] =
            simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
        param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
        if (_weight_decay > 0 && !_adamw_mode) {
          grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
        }
        momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
        momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
        variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
        grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
        variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
        grad_4[j] = vsqrtq_f32(variance_4[j]);
        grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
        grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
        if (_weight_decay > 0 && _adamw_mode) {
          param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
        }
        param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
        simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
                          i + SIMD_WIDTH * j);
        simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
                          i + SIMD_WIDTH * j);
      }
    }
  }
#endif
  if (_param_size > rounded_size) {
    Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
           scalar_seek_offset(grads, grad_dtype, rounded_size),
           scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
           scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
           (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
           exp_avg_sq_dtype, loss_scale);
  }
}
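Step_4 and Step_8 differ from Step_1 only in how many SIMD_WIDTH-lane registers they process per loop iteration; whatever does not fill a complete group is handed down to the next narrower kernel through the `_param_size > rounded_size` branch. A worked example of that decomposition, assuming SIMD_WIDTH = 4 as defined in cpu_adam_arm.h and a parameter tensor of 1003 elements:

  Step_8: rounded_size = ROUND_DOWN(1003, 32) = 992 elements, processed 32 at a time
  Step_4: called on the remaining 11; ROUND_DOWN(11, 16) = 0, so its vector loop is skipped
  Step_1: ROUND_DOWN(11, 4) = 8 elements processed 4 at a time, and the scalar loop finishes the last 3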
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
                         float epsilon, float weight_decay,
                         bool bias_correction, torch::Tensor &params,
                         torch::Tensor &grads, torch::Tensor &exp_avg,
                         torch::Tensor &exp_avg_sq, float loss_scale) {
  auto params_c = params.contiguous();
  auto grads_c = grads.contiguous();
  auto exp_avg_c = exp_avg.contiguous();
  auto exp_avg_sq_c = exp_avg_sq.contiguous();

  this->IncrementStep(step, beta1, beta2);
  this->update_state(lr, epsilon, weight_decay, bias_correction);
  this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
               exp_avg_sq_c.data_ptr(), params_c.numel(),
               params_c.scalar_type(), grads_c.scalar_type(),
               exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
}
namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
      .def(py::init<float, float, float, float, float, bool>())
      .def("step", &AdamOptimizer::step);
}
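A minimal sketch of driving this binding from Python. The loader call and the module name cpu_adam_arm are assumptions for illustration (ColossalAI builds and names the extension through its own machinery); the constructor and step() arguments follow the py::init<float, float, float, float, float, bool>() and AdamOptimizer::step signatures above:

import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build of the kernel shown above.
cpu_adam_arm = load(name="cpu_adam_arm", sources=["cpu_adam_arm.cpp"],
                    extra_cflags=["-O3", "-fopenmp"], extra_ldflags=["-fopenmp"])

# Constructor: lr, beta1, beta2, eps, weight_decay, adamw_mode
opt = cpu_adam_arm.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 1e-2, True)

p = torch.randn(1024)              # fp32 parameters on CPU
g = torch.randn_like(p)            # gradients
exp_avg = torch.zeros_like(p)      # Adam first moment
exp_avg_sq = torch.zeros_like(p)   # Adam second moment

# step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
# params, grads, exp_avg, exp_avg_sq, loss_scale (a non-positive value skips gradient unscaling)
opt.step(1, 1e-3, 0.9, 0.999, 1e-8, 1e-2, True, p, g, exp_avg, exp_avg_sq, -1.0)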
colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h (new file, 201 lines)
@@ -0,0 +1,201 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>

#include <cmath>

#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)

#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4

inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
                                    size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float: {
      auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
      return vld1q_f32(ptr_f + offset);
    }
    case at::ScalarType::Half: {
      auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
      return vcvt_f32_f16(vld1_f16(ptr_h + offset));
    }
    // case at::ScalarType::BFloat16: {
    //   auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
    //   return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
    // }
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
  return simd_load_offset(ptr, dtype, 0);
}

inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
                              size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float: {
      auto ptr_f = reinterpret_cast<float32_t *>(ptr);
      vst1q_f32(ptr_f + offset, data);
      break;
    }
    case at::ScalarType::Half: {
      auto ptr_h = reinterpret_cast<float16_t *>(ptr);
      vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
      break;
    }
    // case at::ScalarType::BFloat16: {
    //   auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
    //   vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
    //   break;
    // }
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}

inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
  return simd_store_offset(ptr, dtype, data, 0);
}

inline float32x4_t simd_set(float value) {
  auto val = static_cast<float32_t>(value);
  return vdupq_n_f32(val);
}

#endif

inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      return *(reinterpret_cast<const float *>(ptr) + offset);
    case at::ScalarType::Half:
      return static_cast<float>(
          *(reinterpret_cast<const at::Half *>(ptr) + offset));
    // case at::ScalarType::BFloat16:
    //   return static_cast<float>(
    //       *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}

inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      *(reinterpret_cast<float *>(ptr) + offset) = data;
      break;
    case at::ScalarType::Half:
      *(reinterpret_cast<at::Half *>(ptr) + offset) = data;
      break;
    // case at::ScalarType::BFloat16:
    //   *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
      break;
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}

inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
                                size_t offset) {
  switch (dtype) {
    case at::ScalarType::Float:
      return reinterpret_cast<float *>(ptr) + offset;
    case at::ScalarType::Half:
      return reinterpret_cast<at::Half *>(ptr) + offset;
    // case at::ScalarType::BFloat16:
    //   return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
    default:
      AT_ERROR("Unsupported dtype");
      break;
  }
}
#define STEP(SPAN)                                                       \
  void Step_##SPAN(void *_params, void *grads, void *_exp_avg,           \
                   void *_exp_avg_sq, size_t _param_size,                \
                   at::ScalarType param_dtype, at::ScalarType grad_dtype, \
                   at::ScalarType exp_avg_dtype,                         \
                   at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);

class AdamOptimizer {
 private:
  float _alpha;
  float _betta1;
  float _betta2;
  float _eps;
  float _weight_decay;

  float _betta1_t;
  float _betta2_t;
  size_t _step;

  float _bias_correction1;
  float _bias_correction2;

  bool _adamw_mode;

 public:
  AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                float eps = 1e-8, float weight_decay = 0,
                bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~AdamOptimizer() {}

  STEP(1)
  STEP(4)
  STEP(8)
  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }
  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / sqrt(1 - _betta2_t);
    }
  }

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,
            torch::Tensor &grads, torch::Tensor &exp_avg,
            torch::Tensor &exp_avg_sq, float loss_scale);
};
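update_state turns the running powers maintained by IncrementStep into the two correction factors the Step_* kernels consume. A worked example with beta1 = 0.9, beta2 = 0.999, lr = 1e-3 at step 1 with bias_correction enabled:

  _betta1_t = 0.9,  _betta2_t = 0.999
  _bias_correction1 = 1 - 0.9 = 0.1               ->  step_size = -lr / 0.1 = -1e-2
  _bias_correction2 = 1 / sqrt(1 - 0.999) ≈ 31.62  ->  denominator = sqrt(v) * 31.62 + eps

so the first-step moments, which are only partially accumulated, are rescaled to unbiased estimates before the parameter update.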
@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange

-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device


 class Unpad(torch.autograd.Function):
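The import change above is part of the device-utils refactor named in the commit message: helpers that used to live in colossalai.utils.cuda now live in colossalai.utils.device so callers no longer assume CUDA. A minimal sketch of the idea behind such a device-agnostic helper; the dispatch logic and the NPU branch are assumptions for illustration, not the actual colossalai.utils.device implementation:

import torch

def get_current_device() -> torch.device:
    # Assumed preference order: CUDA, then Ascend NPU (via torch_npu), then CPU.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.device(f"npu:{torch.npu.current_device()}")
    return torch.device("cpu")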