fix format (#611)

Yuer867 2022-04-01 14:19:27 +08:00 committed by binmakeswell
parent d3d5bedc65
commit 5ecef13c16
2 changed files with 5 additions and 6 deletions

View File

@@ -9,7 +9,7 @@
#include "cuda_util.h"
class Context {
 public:
  Context() : _stream(nullptr) {
    CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
  }
@@ -30,7 +30,7 @@ class Context {
  cublasHandle_t get_cublashandle() { return _cublasHandle; }
 private:
  cudaStream_t _stream;
  cublasHandle_t _cublasHandle;
};
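For orientation, a minimal usage sketch of a context like this: the cuBLAS handle created in the constructor is bound to a caller-owned CUDA stream before any cuBLAS work is enqueued. Only cudaStreamCreate, cublasSetStream, and cudaStreamDestroy are real CUDA/cuBLAS API here; wrapping them all in CHECK_GPU_ERROR mirrors the constructor above but is an assumption about what that macro accepts, and this snippet is not part of the commit (it also assumes the relevant headers are included).

// Hedged usage sketch (not from this commit): bind a stream to the handle.
cudaStream_t stream;
CHECK_GPU_ERROR(cudaStreamCreate(&stream));

Context ctx;  // constructor above creates _cublasHandle via cublasCreate
CHECK_GPU_ERROR(cublasSetStream(ctx.get_cublashandle(), stream));

// ... enqueue cuBLAS calls on `stream` through ctx.get_cublashandle() ...

CHECK_GPU_ERROR(cudaStreamDestroy(stream));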

View File

@@ -8,9 +8,8 @@
#include "cuda_util.h"
-template <typename T>
-class CrossEntropyLayer {
+template <typename T> class CrossEntropyLayer {
 public:
  CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
  virtual ~CrossEntropyLayer();
@@ -23,7 +22,7 @@ class CrossEntropyLayer {
  void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
 private:
  void allocate_mem_buffer() {
    // allocate local gpu memory
    _loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
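The call above reserves 2 * _max_batch_tokens floats of device memory for the loss results. A minimal sketch of what a typed cuda_malloc<T> helper of this shape could look like, assuming it is a thin wrapper over cudaMalloc; the actual definition in cuda_util.h may differ:

#include <cuda_runtime.h>
#include <stdexcept>

// Hypothetical helper in the spirit of cuda_malloc<T>: allocate `count`
// elements of type T on the device, throwing on failure.
template <typename T>
T *cuda_malloc(size_t count) {
  T *ptr = nullptr;
  cudaError_t err = cudaMalloc(reinterpret_cast<void **>(&ptr), count * sizeof(T));
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
  return ptr;
}

// Matching release helper for buffers obtained from cuda_malloc<T>.
template <typename T>
void cuda_free(T *ptr) {
  if (ptr != nullptr) cudaFree(ptr);
}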