[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-09-22 17:36:42 +00:00
parent c779f4e0e4
commit 4e2092041f
17 changed files with 126 additions and 126 deletions

View File

@@ -81,11 +81,11 @@ with gr.Blocks(css=CSS) as demo:
)
with gr.Row():
btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple", size="sm")
restart_btn = gr.Button(str("\u21BB"), elem_id="restart-btn", scale=1)
restart_btn = gr.Button(str("\u21bb"), elem_id="restart-btn", scale=1)
txt = gr.Textbox(
scale=8,
show_label=False,
placeholder="Enter text and press enter, or use 📁 to upload files, click \u21BB to clear loaded files and restart chat",
placeholder="Enter text and press enter, or use 📁 to upload files, click \u21bb to clear loaded files and restart chat",
container=True,
autofocus=True,
)

View File

@@ -1,6 +1,6 @@
"""This code is adapted from Alpa
https://github.com/alpa-projects/alpa/
with some changes. """
with some changes."""
import multiprocessing
import time

View File

@@ -1,6 +1,6 @@
"""This code is adapted from Alpa
https://github.com/alpa-projects/alpa/
with some changes. """
with some changes."""
import operator
from dataclasses import dataclass

View File

@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OpenMoE model."""
"""PyTorch OpenMoE model."""
import math
from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,6 @@
"""This code is from NVIDIA apex:
https://github.com/NVIDIA/apex
with some changes. """
with some changes."""
import numbers

View File

@@ -1,4 +1,4 @@
""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
"""adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
import warnings
from collections import defaultdict

View File

@@ -1,4 +1,4 @@
""" adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
"""adapted from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py"""
import warnings
from typing import List

View File

@@ -1,4 +1,4 @@
""" PyTorch ChatGLM model. """
"""PyTorch ChatGLM model."""
from typing import List, Optional, Tuple

View File

@@ -34,8 +34,8 @@ class PreTrainingDataset:
self.do_whole_word_mask = do_whole_word_mask
self.max_predictions_per_seq = max_predictions_per_seq
self.vocab_words = list(tokenizer.vocab.keys())
self.rec = re.compile("[\u4E00-\u9FA5]")
self.whole_rec = re.compile("##[\u4E00-\u9FA5]")
self.rec = re.compile("[\u4e00-\u9fa5]")
self.whole_rec = re.compile("##[\u4e00-\u9fa5]")
self.mlm_p = 0.15
self.mlm_mask_p = 0.8

View File

@@ -75,15 +75,15 @@ auto get_new_segment(
return new_segment;
}
bool startsWith(const std::string &s, const std::string &sub) {
bool startsWith(const std::string& s, const std::string& sub) {
return s.find(sub) == 0 ? true : false;
}
auto create_whole_masked_lm_predictions(
std::vector<std::string> &tokens,
const std::vector<std::string> &original_tokens,
const std::vector<std::string> &vocab_words,
std::map<std::string, int> &vocab, const int max_predictions_per_seq,
std::vector<std::string>& tokens,
const std::vector<std::string>& original_tokens,
const std::vector<std::string>& vocab_words,
std::map<std::string, int>& vocab, const int max_predictions_per_seq,
const double masked_lm_prob) {
// for (auto item : vocab) {
// std::cout << "key=" << std::string(py::str(item.first)) << ", "

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DeBERTa-v2 model."""
"""PyTorch DeBERTa-v2 model."""
import math
from collections.abc import Sequence

View File

@@ -1,7 +1,7 @@
#include "cpu_adam_arm.h"
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
void AdamOptimizer::Step_1(void* _params, void* grads, void* _exp_avg,
void* _exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
@@ -106,8 +106,8 @@ void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
}
}
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
void AdamOptimizer::Step_4(void* _params, void* grads, void* _exp_avg,
void* _exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
@@ -192,8 +192,8 @@ void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
}
}
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
void AdamOptimizer::Step_8(void* _params, void* grads, void* _exp_avg,
void* _exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
@@ -279,9 +279,9 @@ void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
bool bias_correction, torch::Tensor& params,
torch::Tensor& grads, torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();

View File

@@ -11,15 +11,15 @@
#include <arm_neon.h>
#define SIMD_WIDTH 4
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
inline float32x4_t simd_load_offset(const void* ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
auto ptr_f = reinterpret_cast<const float32_t*>(ptr);
return vld1q_f32(ptr_f + offset);
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
auto ptr_h = reinterpret_cast<const float16_t*>(ptr);
return vcvt_f32_f16(vld1_f16(ptr_h + offset));
}
// case at::ScalarType::BFloat16: {
@@ -31,20 +31,20 @@ inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
break;
}
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
inline float32x4_t simd_load(void const* ptr, at::ScalarType dtype) {
return simd_load_offset(ptr, dtype, 0);
}
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
inline void simd_store_offset(void* ptr, at::ScalarType dtype, float32x4_t data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
auto ptr_f = reinterpret_cast<float32_t*>(ptr);
vst1q_f32(ptr_f + offset, data);
break;
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
auto ptr_h = reinterpret_cast<float16_t*>(ptr);
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
break;
}
@@ -59,7 +59,7 @@ inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
}
}
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
inline void simd_store(void* ptr, at::ScalarType dtype, float32x4_t data) {
return simd_store_offset(ptr, dtype, data, 0);
}
@@ -70,14 +70,14 @@ inline float32x4_t simd_set(float value) {
#endif
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
inline float scalar_load_offset(const void* ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return *(reinterpret_cast<const float *>(ptr) + offset);
return *(reinterpret_cast<const float*>(ptr) + offset);
case at::ScalarType::Half:
return static_cast<float>(
*(reinterpret_cast<const at::Half *>(ptr) + offset));
*(reinterpret_cast<const at::Half*>(ptr) + offset));
// case at::ScalarType::BFloat16:
// return static_cast<float>(
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
@@ -87,14 +87,14 @@ inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
}
}
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
inline void scalar_store_offset(void* ptr, at::ScalarType dtype, float data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
*(reinterpret_cast<float *>(ptr) + offset) = data;
*(reinterpret_cast<float*>(ptr) + offset) = data;
break;
case at::ScalarType::Half:
*(reinterpret_cast<at::Half *>(ptr) + offset) = data;
*(reinterpret_cast<at::Half*>(ptr) + offset) = data;
break;
// case at::ScalarType::BFloat16:
// *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
@@ -105,13 +105,13 @@ inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
}
}
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
inline void* scalar_seek_offset(void* ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return reinterpret_cast<float *>(ptr) + offset;
return reinterpret_cast<float*>(ptr) + offset;
case at::ScalarType::Half:
return reinterpret_cast<at::Half *>(ptr) + offset;
return reinterpret_cast<at::Half*>(ptr) + offset;
// case at::ScalarType::BFloat16:
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
default:
@@ -120,8 +120,8 @@ inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
}
}
#define STEP(SPAN) \
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
void *_exp_avg_sq, size_t _param_size, \
void Step_##SPAN(void* _params, void* grads, void* _exp_avg, \
void* _exp_avg_sq, size_t _param_size, \
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
at::ScalarType exp_avg_dtype, \
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
@@ -195,7 +195,7 @@ class AdamOptimizer {
}
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
float weight_decay, bool bias_correction, torch::Tensor& params,
torch::Tensor& grads, torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq, float loss_scale);
};

View File

@@ -9,36 +9,36 @@ namespace cuda {
namespace utils {
template <typename T, int VecSize>
__device__ __inline__ void copy_zero(T *dst) {
__device__ __inline__ void copy_zero(T* dst) {
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
*(reinterpret_cast<VT*>(dst)) = funcs::CastFunctor<float, VT>()(0.0f);
}
template <typename SrcT, typename DstT, int VecSize>
__device__ __inline__ void copy(const SrcT *src, DstT *dst) {
__device__ __inline__ void copy(const SrcT* src, DstT* dst) {
using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
*(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
*(reinterpret_cast<const SrcVT *>(src)));
*(reinterpret_cast<DstVT*>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
*(reinterpret_cast<const SrcVT*>(src)));
}
template <typename T, int VecSize>
__device__ __inline__ void copy(const T *src, T *dst) {
__device__ __inline__ void copy(const T* src, T* dst) {
using VT = typename common::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
*(reinterpret_cast<VT*>(dst)) = *(reinterpret_cast<const VT*>(src));
}
template <>
__device__ __inline__ void copy<float, float, 8>(const float *src, float *dst) {
__device__ __inline__ void copy<float, float, 8>(const float* src, float* dst) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
*(reinterpret_cast<float4 *>(dst + 4)) =
*(reinterpret_cast<const float4 *>(src + 4));
*(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
*(reinterpret_cast<float4*>(dst + 4)) =
*(reinterpret_cast<const float4*>(src + 4));
}
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
int get_vec_size(const torch::Tensor& tensor) {
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr());
const int max_aligned_size = 128;
const int dtype_size = sizeof(T) * 8;

View File

@@ -32,8 +32,8 @@ SOFTWARE
// C++ interface
void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
void Adam_Optimizer::Step_1(float* _params, float* grads, float* _exp_avg,
float* _exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
@@ -44,10 +44,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
__half* params_cast_h = reinterpret_cast<__half*>(_params);
__half* grads_cast_h = reinterpret_cast<__half*>(grads);
__half* momentum_cast_h = reinterpret_cast<__half*>(_exp_avg);
__half* variance_cast_h = reinterpret_cast<__half*>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
@@ -182,17 +182,17 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
}
}
void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
void Adam_Optimizer::Step_4(float* _params, float* grads, float* _exp_avg,
float* _exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
__half* params_cast_h = reinterpret_cast<__half*>(_params);
__half* grads_cast_h = reinterpret_cast<__half*>(grads);
__half* momentum_cast_h = reinterpret_cast<__half*>(_exp_avg);
__half* variance_cast_h = reinterpret_cast<__half*>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
@@ -285,29 +285,29 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
}
#endif
if (_param_size > rounded_size)
Step_1((param_half_precision ? (float *)(params_cast_h + rounded_size)
Step_1((param_half_precision ? (float*)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
(grad_half_precision ? (float*)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
(momentum_half_precision ? (float*)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
(variance_half_precision ? (float*)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
variance_half_precision, loss_scale);
}
void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
float *_exp_avg_sq, size_t _param_size,
void Adam_Optimizer::Step_8(float* _params, float* grads, float* _exp_avg,
float* _exp_avg_sq, size_t _param_size,
bool param_half_precision, bool grad_half_precision,
bool momentum_half_precision,
bool variance_half_precision, float loss_scale) {
size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
__half *params_cast_h = reinterpret_cast<__half *>(_params);
__half *grads_cast_h = reinterpret_cast<__half *>(grads);
__half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg);
__half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq);
__half* params_cast_h = reinterpret_cast<__half*>(_params);
__half* grads_cast_h = reinterpret_cast<__half*>(grads);
__half* momentum_cast_h = reinterpret_cast<__half*>(_exp_avg);
__half* variance_cast_h = reinterpret_cast<__half*>(_exp_avg_sq);
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
AVX_Data betta1_4;
@@ -400,13 +400,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
}
#endif
if (_param_size > rounded_size)
Step_4((param_half_precision ? (float *)(params_cast_h + rounded_size)
Step_4((param_half_precision ? (float*)(params_cast_h + rounded_size)
: _params + rounded_size),
(grad_half_precision ? (float *)(grads_cast_h + rounded_size)
(grad_half_precision ? (float*)(grads_cast_h + rounded_size)
: grads + rounded_size),
(momentum_half_precision ? (float *)(momentum_cast_h + rounded_size)
(momentum_half_precision ? (float*)(momentum_cast_h + rounded_size)
: _exp_avg + rounded_size),
(variance_half_precision ? (float *)(variance_cast_h + rounded_size)
(variance_half_precision ? (float*)(variance_cast_h + rounded_size)
: _exp_avg_sq + rounded_size),
(_param_size - rounded_size), param_half_precision,
grad_half_precision, momentum_half_precision,
@@ -415,18 +415,18 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
bool bias_correction, torch::Tensor& params,
torch::Tensor& grads, torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);

View File

@@ -49,9 +49,9 @@ SOFTWARE
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_LOAD_HALF(x) \
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
_mm256_storeu_ps((float*)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#elif defined(__AVX256__) or defined(__AVX2__)
@@ -65,9 +65,9 @@ SOFTWARE
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(x)))
#define SIMD_STORE_HALF(x, d) \
_mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
_mm_storeu_ps((float*)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \
d, _MM_FROUND_TO_NEAREST_INT)))
#endif
@@ -85,7 +85,7 @@ union AVX_Data {
#define STEP(SPAN) \
void Step_##SPAN( \
float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \
float* _params, float* grads, float* _exp_avg, float* _exp_avg_sq, \
size_t _param_size, bool param_half_precision = false, \
bool grad_half_precision = false, bool momentum_half_precision = false, \
bool variance_half_precision = false, float loss_scale = -1);
@@ -143,8 +143,8 @@ class Adam_Optimizer {
}
#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
inline void simd_load(bool is_half, float* ptr, __half* h_ptr,
AVX_Data& data) {
if (is_half) {
data.data = SIMD_LOAD_HALF(h_ptr);
} else {
@@ -152,8 +152,8 @@ class Adam_Optimizer {
}
}
inline void simd_store(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
inline void simd_store(bool is_half, float* ptr, __half* h_ptr,
AVX_Data& data) {
if (is_half) {
SIMD_STORE_HALF(h_ptr, data.data);
} else {
@@ -163,9 +163,9 @@ class Adam_Optimizer {
#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
float weight_decay, bool bias_correction, torch::Tensor& params,
torch::Tensor& grads, torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq, float loss_scale);
private:
float _alpha;

View File

@@ -11,8 +11,8 @@
namespace {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int& n1,
int& n2) {
int idiff = input.ndimension() - normalized_shape.size();
n2 = 1;
for (int i = 0; i < (int)normalized_shape.size(); ++i) {
@@ -31,8 +31,8 @@ void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma,
TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
int &n2) {
void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int& n1,
int& n2) {
int64_t normalized_ndim = normalized_shape.size();
if (normalized_ndim < 1) {
@@ -63,16 +63,16 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
}
void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, int &n1, int &n2) {
at::Tensor gamma, at::Tensor beta, int& n1, int& n2) {
check_args(input, normalized_shape, n1, n2);
check_args(normalized_shape, gamma, beta);
}
} // namespace
void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
at::Tensor *input, int n1, int n2,
at::IntArrayRef normalized_shape, at::Tensor *gamma,
at::Tensor *beta, double epsilon);
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
at::Tensor* input, int n1, int n2,
at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
@@ -103,12 +103,12 @@ std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
return {output, mean, invvar};
}
void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
at::Tensor *invvar, at::Tensor *input, int n1,
void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean,
at::Tensor* invvar, at::Tensor* input, int n1,
int n2, at::IntArrayRef normalized_shape,
at::Tensor *gamma, at::Tensor *beta,
double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta);
at::Tensor* gamma, at::Tensor* beta,
double epsilon, at::Tensor* grad_input,
at::Tensor* grad_gamma, at::Tensor* grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,