[NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h code style (#962)

MaxT authored on 2022-05-15 09:01:08 +08:00; committed by binmakeswell
parent 89e2767a92
commit 442a2975ab


@@ -19,21 +19,25 @@
 template <typename T>
 class MultiHeadAttention {
  public:
-  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
-                     int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
+  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
+                     int hidden_size, int num_heads, float attn_dropout_ratio,
+                     float hidden_output_dropout_ratio,
                      bool pre_or_postLayerNorm);
 
   virtual ~MultiHeadAttention();
 
   void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
 
-  void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
-                const T *input_mask_ptr, T *grad_input_ptr);
+  void Backward(const T *grad_output_ptr, const T *input_ptr,
+                const T *output_ptr, const T *input_mask_ptr,
+                T *grad_input_ptr);
 
-  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
+  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
+                     T *buffer);
 
-  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
-                     const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
+  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
+                     const T *output_ptr, const T *grad_output_ptr,
+                     T *grad_input_attn_layer_bwptr, T *buffer);
 
   void set_cur_batch_shape(int batch_size, int seq_len) {
     _batch_size = batch_size;
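For readers skimming the diff, the declarations in the hunk above form the host-side interface of the fused attention layer. The snippet below is a minimal usage sketch of that interface, not part of this commit: the helper name run_attention_step, the layer sizes and dropout ratios, and the assumption that every pointer argument is a caller-owned device buffer (and that no further setup such as a process group is needed) are hypothetical.

// Illustrative sketch only, not from this commit: one way the interface
// declared above could be driven. All sizes are made up, and the pointers
// are assumed to be device buffers owned by the caller.
#include "multihead_attention_1d.h"

void run_attention_step(const float *input_ptr, const float *input_mask_ptr,
                        float *out_ptr, const float *grad_output_ptr,
                        float *grad_input_ptr) {
  MultiHeadAttention<float> layer(/*layer_id=*/0,
                                  /*max_batch_tokens=*/4096,
                                  /*_max_seq_len=*/512,
                                  /*hidden_size=*/1024,
                                  /*num_heads=*/16,
                                  /*attn_dropout_ratio=*/0.1f,
                                  /*hidden_output_dropout_ratio=*/0.1f,
                                  /*pre_or_postLayerNorm=*/true);

  // The current batch geometry is set before each forward/backward pass.
  layer.set_cur_batch_shape(/*batch_size=*/8, /*seq_len=*/128);
  layer.Forward(input_ptr, input_mask_ptr, out_ptr);
  layer.Backward(grad_output_ptr, input_ptr, out_ptr, input_mask_ptr,
                 grad_input_ptr);
}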
@@ -83,14 +87,17 @@ class MultiHeadAttention {
     }
 
     _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
-    _soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
-    _ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _soft_out_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _ctx_bufB_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
     _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
 
     // buffer size needed by attn bw
-    size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
-                       std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
-                                _max_batch_tokens * _heads / pg_size * _max_seq_len);
+    size_t smem_size =
+        4 * _max_batch_tokens * _hidden_size / pg_size +
+        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
+                 _max_batch_tokens * _heads / pg_size * _max_seq_len);
 
     if (!_shared_mem_ptr) {
       cuda_free(_shared_mem_ptr);
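As a side note, the reflowed smem_size expression above sizes the shared backward workspace as four hidden-size activations per token plus the larger of three more such activations or one attention-score matrix per head. The standalone sketch below only reproduces that arithmetic; the capacities and pg_size = 1 (no tensor parallelism) are made-up values for illustration, not taken from the commit.

// Illustrative arithmetic only: mirrors the buffer-size formula from the hunk
// above using hypothetical capacities. pg_size = 1 models a single
// tensor-parallel rank.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t max_batch_tokens = 4096;  // hypothetical capacity
  const std::size_t max_seq_len = 512;        // hypothetical capacity
  const std::size_t hidden_size = 1024;       // hypothetical model width
  const std::size_t heads = 16;               // hypothetical head count
  const std::size_t pg_size = 1;              // no tensor parallelism

  // Same shape as the expression in the diff: 4 hidden-size buffers per
  // token, plus the larger of 3 more such buffers or one score matrix
  // (heads * seq_len entries) per token.
  const std::size_t smem_size =
      4 * max_batch_tokens * hidden_size / pg_size +
      std::max(3 * max_batch_tokens * hidden_size / pg_size,
               max_batch_tokens * heads / pg_size * max_seq_len);

  std::printf("backward buffer needs %zu elements\n", smem_size);
  return 0;
}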