[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h (#968)

code style
This commit is contained in:
ExtremeViscent 2022-05-16 03:20:48 +01:00 committed by binmakeswell
parent fb5bc6cb28
commit 22d1df224d

View File

@ -13,14 +13,16 @@
#include "cublas_wrappers.h" #include "cublas_wrappers.h"
#include "kernels.h" #include "kernels.h"
template <typename T> class FeedForward { template <typename T>
public: class FeedForward {
public:
struct Config { struct Config {
int outputSize; int outputSize;
int inputSize; int inputSize;
std::array<int, 3> gemm_algos; std::array<int, 3> gemm_algos;
Config(int outputs, int inputs) Config(int outputs, int inputs)
: outputSize(outputs), inputSize(inputs), : outputSize(outputs),
inputSize(inputs),
gemm_algos(std::array<int, 3>({99, 99, 99})) {} gemm_algos(std::array<int, 3>({99, 99, 99})) {}
}; };
@ -61,6 +63,6 @@ public:
config_.inputSize = inputSize; config_.inputSize = inputSize;
} }
private: private:
Config config_; Config config_;
}; };