diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index 978c72fed..005a36ba1 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -10,9 +10,8 @@ using namespace std; -template -class Softmax { - public: +template class Softmax { +public: struct Config { size_t nhead; Config(size_t nhead) : nhead(nhead) {} @@ -35,10 +34,8 @@ class Softmax { stream); } - void reset_size(size_t nhead) { - config_.nhead = nhead; - } + void reset_size(size_t nhead) { config_.nhead = nhead; } - private: +private: Config config_; };