diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h index ec963259f..9a43aeec3 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h @@ -13,16 +13,14 @@ #include "cublas_wrappers.h" #include "kernels.h" -template -class FeedForward { - public: +template class FeedForward { +public: struct Config { int outputSize; int inputSize; std::array gemm_algos; Config(int outputs, int inputs) - : outputSize(outputs), - inputSize(inputs), + : outputSize(outputs), inputSize(inputs), gemm_algos(std::array({99, 99, 99})) {} }; @@ -63,6 +61,6 @@ class FeedForward { config_.inputSize = inputSize; } - private: +private: Config config_; };