mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-03 04:39:43 +00:00)
160 lines · 4.8 KiB · C++
/*
Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
#pragma once

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

#include <cmath>  // std::pow, used by Adam_Optimizer::IncrementStep below

#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif

// Round `size` down to the nearest multiple of `step`
// (`step` must be a power of two), e.g. ROUND_DOWN(1003, 8) == 1000.
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// Chunk size (in elements) used when tiling over the parameter buffer.
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)

// AVX-512: 16 packed floats per vector register.
#if defined(__AVX512__)
#define SIMD_WIDTH 16
#define INTV __m256i
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
// fp16 <-> fp32 conversion: 16 halves occupy one 256-bit integer vector.
#define SIMD_LOAD_HALF(x) \
  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm256_store_ps(            \
      x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

// AVX2/AVX-256: 8 packed floats per vector register.
#elif defined(__AVX256__) or defined(__AVX2__)
#define SIMD_WIDTH 8
#define INTV __m128i
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_STORE_HALF(x, d) \
  _mm_store_ps(               \
      x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))

#endif

// One SIMD register's worth of floats, named for readability at call sites.
union AVX_Data {
#if defined(__AVX512__)
  __m512 data;
#elif defined(__AVX256__) or defined(__AVX2__)
  __m256 data;
#endif
  // float data_f[16];
};

#endif
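
// Illustrative sketch (not part of the original header) of how the macros
// above compose; `src`, `dst`, and `n` are hypothetical names, with `n`
// assumed to be a multiple of SIMD_WIDTH:
//
//   AVX_Data scale;
//   scale.data = SIMD_SET(0.5f);                      // broadcast a scalar
//   for (size_t i = 0; i < n; i += SIMD_WIDTH) {
//     AVX_Data g, p;
//     g.data = SIMD_LOAD(src + i);                    // SIMD_WIDTH floats
//     p.data = SIMD_LOAD(dst + i);
//     p.data = SIMD_FMA(g.data, scale.data, p.data);  // p += 0.5f * g
//     SIMD_STORE(dst + i, p.data);                    // write back
//   }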

// Declares the Adam update entry points; SPAN names each loop-width variant.
// The definitions live in the accompanying cpu_adam.cpp.
#define STEP(SPAN)                                                \
  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
                   float *_exp_avg_sq, size_t _param_size,        \
                   bool param_half_precision = false,             \
                   bool grad_half_precision = false, float loss_scale = -1);
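
// For reference, STEP(1) below expands to this declaration (Step_4 and
// Step_8 differ only in name; their definitions presumably process wider
// spans of the same update loop):
//
//   void Step_1(float *_params, float *grads, float *_exp_avg,
//               float *_exp_avg_sq, size_t _param_size,
//               bool param_half_precision = false,
//               bool grad_half_precision = false, float loss_scale = -1);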

class Adam_Optimizer {
 public:
  Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                 float eps = 1e-8, float weight_decay = 0,
                 bool adamw_mode = true)
      : _alpha(alpha),
        _betta1(betta1),
        _betta2(betta2),
        _eps(eps),
        _weight_decay(weight_decay),
        _betta1_t(1.0),
        _betta2_t(1.0),
        _step(0),
        _adamw_mode(adamw_mode) {}
  ~Adam_Optimizer() {}

  // Update variants generated by the STEP macro above.
  STEP(1)
  STEP(4)
  STEP(8)
  // Track the current step and cache betta^step. For consecutive steps with
  // unchanged betas, the cached powers advance with one multiply each
  // instead of a std::pow call.
  inline void IncrementStep(size_t step, float beta1, float beta2) {
    if (beta1 != _betta1 || beta2 != _betta2) {
      // Betas changed: recompute the cached powers from scratch.
      _step = step;
      _betta1 = beta1;
      _betta2 = beta2;
      _betta1_t = std::pow(_betta1, step);
      _betta2_t = std::pow(_betta2, step);
    } else {
      _step++;
      if (_step != step) {
        // Non-consecutive step: recompute rather than accumulate.
        _betta1_t = std::pow(_betta1, step);
        _betta2_t = std::pow(_betta2, step);
        _step = step;
      } else {
        // Common fast path: advance the powers incrementally.
        _betta1_t *= _betta1;
        _betta2_t *= _betta2;
      }
    }
  }
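
  // Worked example (illustrative): with the default betas, the calls
  // IncrementStep(1, 0.9f, 0.999f) then IncrementStep(2, 0.9f, 0.999f)
  // take the fast path, so _betta1_t advances 1.0 -> 0.9 -> 0.81, i.e.
  // pow(0.9, _step) is maintained without calling std::pow. Jumping to a
  // non-consecutive step, or changing the betas, recomputes the powers.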
  // Refresh the hyperparameters and the Adam bias-correction terms:
  //   bias_correction1 = 1 - beta1^t
  //   bias_correction2 = 1 / sqrt(1 - beta2^t)
  inline void update_state(float lr, float epsilon, float weight_decay,
                           bool bias_correction) {
    _alpha = lr;
    _eps = epsilon;
    _weight_decay = weight_decay;

    _bias_correction1 = 1.0f;
    _bias_correction2 = 1.0f;
    if (bias_correction) {
      _bias_correction1 = 1 - _betta1_t;
      _bias_correction2 = 1 / std::sqrt(1 - _betta2_t);
    }
  }

 private:
  float _alpha;  // learning rate
  float _betta1;
  float _betta2;
  float _eps;
  float _weight_decay;

  float _betta1_t;  // cached beta1^step
  float _betta2_t;  // cached beta2^step
  size_t _step;

  float _bias_correction1;
  float _bias_correction2;

  bool _adamw_mode;  // true: decoupled weight decay (AdamW); false: L2 term
};
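
// End-to-end usage sketch (hypothetical caller; assumes the Step_* bodies
// are defined in the matching cpu_adam.cpp and that all four buffers are
// float arrays of length param_size):
//
//   Adam_Optimizer opt(/*alpha=*/1e-3f, /*betta1=*/0.9f, /*betta2=*/0.999f,
//                      /*eps=*/1e-8f, /*weight_decay=*/0.0f,
//                      /*adamw_mode=*/true);
//   for (size_t t = 1; t <= num_steps; ++t) {
//     opt.IncrementStep(t, 0.9f, 0.999f);  // keep the beta powers in sync
//     opt.update_state(lr, 1e-8f, 0.0f, /*bias_correction=*/true);
//     opt.Step_8(params, grads, exp_avg, exp_avg_sq, param_size);
//   }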