diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h index 9a43aeec3..ec963259f 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h @@ -13,14 +13,16 @@ #include "cublas_wrappers.h" #include "kernels.h" -template class FeedForward { -public: +template +class FeedForward { + public: struct Config { int outputSize; int inputSize; std::array gemm_algos; Config(int outputs, int inputs) - : outputSize(outputs), inputSize(inputs), + : outputSize(outputs), + inputSize(inputs), gemm_algos(std::array({99, 99, 99})) {} }; @@ -61,6 +63,6 @@ public: config_.inputSize = inputSize; } -private: + private: Config config_; };