From d8d07b0e2becc88f80590f0df44d3bb952badf37 Mon Sep 17 00:00:00 2001
From: Sze-qq <68757353+Sze-qq@users.noreply.github.com>
Date: Fri, 13 May 2022 21:56:06 +0800
Subject: [PATCH] [NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp code style (#952)

---
 .../csrc/multihead_attention_1d.cpp | 228 ++++++++++--------
 1 file changed, 132 insertions(+), 96 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
index 63bf633f5..b02556f79 100644
--- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
+++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp
@@ -10,8 +10,9 @@
 #include "kernels.h"
 
 template <typename T>
-MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, int max_seq_len,
-                                          int hidden_size, int num_heads,
+MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
+                                          int max_seq_len, int hidden_size,
+                                          int num_heads,
                                           float attn_prob_dropout_ratio,
                                           float hidden_output_dropout_ratio,
                                           bool pre_or_postLayerNorm)
@@ -22,18 +23,22 @@ MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, in
       _heads(num_heads),
       _training(true),
       _pre_or_postLayerNorm(pre_or_postLayerNorm),
-      _qkv_linear(typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
-      _attn_out_linear(typename FeedForward<T>::Config(hidden_size, hidden_size)),
-      _attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false), _max_batch_tokens),
+      _qkv_linear(
+          typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
+      _attn_out_linear(
+          typename FeedForward<T>::Config(hidden_size, hidden_size)),
+      _attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
+               _max_batch_tokens),
       _softmax(typename Softmax<T>::Config(num_heads)),
       _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
                          _max_batch_tokens * _heads * _max_seq_len),
       _attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
                     _max_batch_tokens * _hidden_size),
-      _attn_scores(typename StridedBatchGemm<T>::Config((T(1.0) / T(sqrt(_hidden_size / _heads))),
-                                                        T(0.0), CUBLAS_OP_T, CUBLAS_OP_N)),
-      _attn_context(
-          typename StridedBatchGemm<T>::Config(T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
+      _attn_scores(typename StridedBatchGemm<T>::Config(
+          (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
+          CUBLAS_OP_N)),
+      _attn_context(typename StridedBatchGemm<T>::Config(
+          T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
   assert(_hidden_size % _heads == 0);
 }
 
@@ -43,43 +48,52 @@ MultiHeadAttention<T>::~MultiHeadAttention() {
 }
 
 template <typename T>
-void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mask_ptr,
+void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
+                                          const T *input_mask_ptr,
                                           T *output_ptr, T *buffer) {
   T *q_tf_ptr = _qkv_ptr;
   T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
   T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;
 
   if (_pre_or_postLayerNorm) {
-    _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
-                     _stream);
+    _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
+                     _batch_tokens, _stream);
   }
-  const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
+  const T *gemmQKV_inp_ptr =
+      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
   _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
-  _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, _cublasHandle);
+  _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
+                      _cublasHandle);
 
-  launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, _batch_size, _seq_len, 3,
-                                  _heads / pg_size, _hidden_size / _heads, _stream);
+  launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr,
+                                  _batch_size, _seq_len, 3, _heads / pg_size,
+                                  _hidden_size / _heads, _stream);
 
   // attention scores, q*k
-  _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
+  _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
+                       _cublasHandle);
 
   // Softmax + Mask
   _softmax.reset_size(_heads / pg_size);
-  _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, _seq_len, _stream, true);
+  _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
+                   _seq_len, _stream, true);
 
   // attn prob dropout.
-  _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, _batch_heads * _seq_len * _seq_len,
-                             _stream);
+  _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
+                             _batch_heads * _seq_len * _seq_len, _stream);
 
   // attention context, score * v
-  _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle);
+  _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
+                        _cublasHandle);
 
   // [b, nh, s, ad] -> [b, s, nh, ad]
-  launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, _hidden_size / pg_size,
-                          _heads / pg_size, 1, _stream);
+  launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
+                          _hidden_size / pg_size, _heads / pg_size, 1,
+                          _stream);
 
   _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
-  _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, output_ptr, _cublasHandle);
+  _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
+                           output_ptr, _cublasHandle);
 
   // allreduce
   if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
@@ -88,24 +102,27 @@ void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mas
     if (typeid(T) != typeid(float)) {
       data_type = torch::kHalf;
     }
-    auto output_tensor =
-        torch::from_blob(output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
-                         torch::TensorOptions(torch::kCUDA).dtype(data_type));
+    auto output_tensor = torch::from_blob(
+        output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
+        torch::TensorOptions(torch::kCUDA).dtype(data_type));
     std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
     auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
     work->wait();
   }
 
-  _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, _attn_ob_ptr,
-                                      _batch_tokens, _hidden_size, _stream);
+  _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
+                                      _attn_ob_ptr, _batch_tokens, _hidden_size,
+                                      _stream);
   if (!_pre_or_postLayerNorm) {
     // in-place ln since ln-input will not be used in post-ln mode
-    _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, _stream);
+    _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
+                     _batch_tokens, _stream);
   }
 }
 
 template <typename T>
-void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr) {
+void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
+                                    T *out_ptr) {
   _stream = Context::Instance().get_stream();
   _cublasHandle = Context::Instance().get_cublashandle();
   T *attn_buffer = _shared_mem_ptr;  // 3 * _batch_dim
@@ -114,8 +131,11 @@ void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
 }
 
 template <typename T>
-void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
-                                          const T *grad_output_ptr, T *grad_input_ptr, T *buffer) {
+void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
+                                          const T *input_mask_ptr,
+                                          const T *output_ptr,
+                                          const T *grad_output_ptr,
+                                          T *grad_input_ptr, T *buffer) {
   cudaStream_t streams[2] = {_stream, _stream};
 
   const T *q_tf_ptr = _qkv_ptr;
@@ -137,45 +157,57 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
   //   batch_size * head_num * seq_len * seq_len);
 
   if (_pre_or_postLayerNorm) {
-    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_output_ptr,
-                                          _batch_tokens, _hidden_size, _stream);
+    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
+                                          grad_output_ptr, _batch_tokens,
+                                          _hidden_size, _stream);
   } else {
-    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, grad_output_ptr,
-                      nullptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
-    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, grad_residual_ptr,
-                                          _batch_tokens, _hidden_size, _stream);
+    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
+                      grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
+                      _attn_nb_ptr, _batch_tokens, streams);
+    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
+                                          grad_residual_ptr, _batch_tokens,
+                                          _hidden_size, _stream);
   }
 
   // bw of output project
   _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
-  _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, _attn_ow_ptr,
-                            _grad_attn_ow_ptr, _grad_attn_ob_ptr, _cublasHandle, _stream,
-                            grad_input_buf_ptr, nullptr, false);
-  launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, _seq_len,
-                        _hidden_size / pg_size, _heads / pg_size, _stream);
+  _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
+                            _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
+                            _cublasHandle, _stream, grad_input_buf_ptr, nullptr,
+                            false);
+  launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size,
+                        _seq_len, _hidden_size / pg_size, _heads / pg_size,
+                        _stream);
 
   // bw of score * v
-  _attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
-                         grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
+  _attn_context.Backward(
+      _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle,
+      grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr);
 
-  _attn_prob_dropout.d_dropout(grad_softmax_ptr, _batch_heads * _seq_len * _seq_len, _stream);
+  _attn_prob_dropout.d_dropout(grad_softmax_ptr,
+                               _batch_heads * _seq_len * _seq_len, _stream);
 
   _softmax.reset_size(_heads / pg_size);
-  _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, _seq_len, _stream);
+  _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
+                    _seq_len, _stream);
 
   // bw of q * k
-  _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle,
-                        grad_qkv_5d_ptr + _batch_dim / pg_size, grad_qkv_5d_ptr);
+  _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
+                        _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
+                        grad_qkv_5d_ptr);
 
   // [3, b, nh, s, ad] -> [b, s, 3, h]
-  launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, _seq_len,
-                          _hidden_size / pg_size, _heads / pg_size, 3, _stream);
+  launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
+                          _seq_len, _hidden_size / pg_size, _heads / pg_size,
+                          3, _stream);
 
-  const T *gemmQKV_inp_ptr = _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
+  const T *gemmQKV_inp_ptr =
+      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
   _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
-  _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, _attn_qkvw_ptr,
-                       _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, _cublasHandle, _stream,
-                       grad_input_buf_ptr, nullptr, true);
+  _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
+                       _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr,
+                       _cublasHandle, _stream, grad_input_buf_ptr, nullptr,
+                       true);
 
   // allreduce
   if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
@@ -185,7 +217,8 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
      data_type = torch::kHalf;
     }
     auto grad_input_tensor =
-        torch::from_blob(grad_input_buf_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
+        torch::from_blob(grad_input_buf_ptr,
+                         {int(_batch_size), int(_seq_len), int(_hidden_size)},
                          torch::TensorOptions(torch::kCUDA).dtype(data_type));
     std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
     auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
     work->wait();
@@ -193,19 +226,21 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
   }
 
   if (_pre_or_postLayerNorm) {
-    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, grad_input_buf_ptr,
-                      grad_output_ptr, gemmQKV_inp_ptr, _attn_nw_ptr, _attn_nb_ptr, _batch_tokens,
-                      streams);
+    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
+                      grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
+                      _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
   } else {
     // FIXME later
-    launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, _batch_size,
-                      _seq_len, _hidden_size, _stream);
+    launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
+                      _batch_size, _seq_len, _hidden_size, _stream);
   }
 }
 
 template <typename T>
-void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
-                                     const T *input_mask_ptr, T *grad_input_ptr) {
+void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
+                                     const T *input_ptr, const T *output_ptr,
+                                     const T *input_mask_ptr,
+                                     T *grad_input_ptr) {
   _stream = Context::Instance().get_stream();
   _cublasHandle = Context::Instance().get_cublashandle();
   T *buffer = _shared_mem_ptr;
@@ -215,7 +250,8 @@ void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_pt
     4 * _batch_dim +
     max(3 * _batch_dim, _batch_size * _head_num * _seq_len * _seq_len);
   */
-  attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, grad_input_ptr, buffer);
+  attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
+                grad_input_ptr, buffer);
 }
 
 template <typename T>
@@ -233,7 +269,8 @@ template class MultiHeadAttention<__half>;
 
 // x is torch::Tensor
 #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CONTIGUOUS(x) \
+  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) \
   CHECK_CUDA(x);       \
   CHECK_CONTIGUOUS(x)
@@ -241,15 +278,17 @@ template class MultiHeadAttention<__half>;
 static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;
 
 template <typename T>
-int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_len, int hidden_dim,
-                               int num_heads, float attn_prob_dropout_ratio,
-                               float hidden_dropout_ratio, bool pre_or_postLayerNorm,
+int create_multihead_attention(int layer_id, int max_batch_tokens,
+                               int max_seq_len, int hidden_dim, int num_heads,
+                               float attn_prob_dropout_ratio,
+                               float hidden_dropout_ratio,
+                               bool pre_or_postLayerNorm,
                                c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   Context::Instance().set_stream(stream);
   auto layer = std::make_shared<MultiHeadAttention<T>>(
-      layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, attn_prob_dropout_ratio,
-      hidden_dropout_ratio, pre_or_postLayerNorm);
+      layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
+      attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);
 
   layer->SetPG(pg_);
 
@@ -261,15 +300,12 @@ int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_l
 }
 
 template <typename T>
-std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Tensor &input,
-                                                  const torch::Tensor &input_mask,
-                                                  const torch::Tensor &in_proj_weight,
-                                                  const torch::Tensor &in_proj_bias,
-                                                  const torch::Tensor &out_proj_weight,
-                                                  const torch::Tensor &out_proj_bias,
-                                                  const torch::Tensor &norm_weight,
-                                                  const torch::Tensor &norm_bias,
-                                                  bool training_mode, bool prelayernorm) {
+std::vector<torch::Tensor> multihead_attention_fw(
+    int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
+    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
+    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
+    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
+    bool training_mode, bool prelayernorm) {
   CHECK_INPUT(input);
   CHECK_INPUT(input_mask);
 
@@ -280,7 +316,8 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
   T *out_ptr = (T *)output.data_ptr();
 
   std::shared_ptr<MultiHeadAttention<T>> layer =
-      std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
+      std::static_pointer_cast<MultiHeadAttention<T>>(
+          s_multihead_attention[layer_id]);
   layer->set_cur_batch_shape(input.size(0), input.size(1));
   layer->SetTrainingMode(training_mode);
 
@@ -297,17 +334,13 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
 }
 
 template <typename T>
-std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
-                                                  const torch::Tensor &grad_dec_output,
-                                                  const torch::Tensor &output,
-                                                  const torch::Tensor &input,
-                                                  const torch::Tensor &input_mask,
-                                                  const torch::Tensor &in_proj_weight,
-                                                  const torch::Tensor &in_proj_bias,
-                                                  const torch::Tensor &out_proj_weight,
-                                                  const torch::Tensor &out_proj_bias,
-                                                  const torch::Tensor &norm_weight,
-                                                  const torch::Tensor &norm_bias) {
+std::vector<torch::Tensor> multihead_attention_bw(
+    int layer_id, const torch::Tensor &grad_dec_output,
+    const torch::Tensor &output, const torch::Tensor &input,
+    const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
+    const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
+    const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
+    const torch::Tensor &norm_bias) {
   auto g_output = grad_dec_output.contiguous();
   CHECK_INPUT(g_output);
   CHECK_INPUT(output);
 
@@ -332,7 +365,8 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
   T *grad_input_ptr = (T *)grad_input.data_ptr();
 
   std::shared_ptr<MultiHeadAttention<T>> layer =
-      std::static_pointer_cast<MultiHeadAttention<T>>(s_multihead_attention[layer_id]);
+      std::static_pointer_cast<MultiHeadAttention<T>>(
+          s_multihead_attention[layer_id]);
 
   layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));
   layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
@@ -342,10 +376,12 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
   layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
   layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();
 
-  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, grad_input_ptr);
+  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
+                  grad_input_ptr);
 
-  return {grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
-          grad_out_proj_bias, grad_norm_weight, grad_norm_bias};
+  return {grad_input, grad_in_proj_weight, grad_in_proj_bias,
+          grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight,
+          grad_norm_bias};
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {