mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-30 17:22:21 +00:00
[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h (#968)
code style
This commit is contained in:
parent
fb5bc6cb28
commit
22d1df224d
@ -13,14 +13,16 @@
|
|||||||
#include "cublas_wrappers.h"
|
#include "cublas_wrappers.h"
|
||||||
#include "kernels.h"
|
#include "kernels.h"
|
||||||
|
|
||||||
template <typename T> class FeedForward {
|
template <typename T>
|
||||||
public:
|
class FeedForward {
|
||||||
|
public:
|
||||||
struct Config {
|
struct Config {
|
||||||
int outputSize;
|
int outputSize;
|
||||||
int inputSize;
|
int inputSize;
|
||||||
std::array<int, 3> gemm_algos;
|
std::array<int, 3> gemm_algos;
|
||||||
Config(int outputs, int inputs)
|
Config(int outputs, int inputs)
|
||||||
: outputSize(outputs), inputSize(inputs),
|
: outputSize(outputs),
|
||||||
|
inputSize(inputs),
|
||||||
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -61,6 +63,6 @@ public:
|
|||||||
config_.inputSize = inputSize;
|
config_.inputSize = inputSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Config config_;
|
Config config_;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user