Mirror of https://github.com/hpcaitech/ColossalAI.git
[NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h code style (#962)
commit 442a2975ab
parent 89e2767a92
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
@@ -19,21 +19,25 @@
 template <typename T>
 class MultiHeadAttention {
  public:
-  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
-                     int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
+  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
+                     int hidden_size, int num_heads, float attn_dropout_ratio,
+                     float hidden_output_dropout_ratio,
                      bool pre_or_postLayerNorm);

   virtual ~MultiHeadAttention();

   void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);

-  void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
-                const T *input_mask_ptr, T *grad_input_ptr);
+  void Backward(const T *grad_output_ptr, const T *input_ptr,
+                const T *output_ptr, const T *input_mask_ptr,
+                T *grad_input_ptr);

-  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
+  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
+                     T *buffer);

-  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
-                     const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
+  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
+                     const T *output_ptr, const T *grad_output_ptr,
+                     T *grad_input_attn_layer_bwptr, T *buffer);

   void set_cur_batch_shape(int batch_size, int seq_len) {
     _batch_size = batch_size;
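For orientation, the hunk above only re-wraps the public interface of MultiHeadAttention<T>; nothing functional changes. Below is a minimal host-side sketch of how that interface might be driven. The shape values, the __half instantiation, and the surrounding function are illustrative assumptions, not code from the repository, and real use may require additional setup not shown here.

// Sketch only: assumes the header above is included and that the class is
// instantiated for __half; argument names follow the declaration in the diff.
#include <cuda_fp16.h>
#include "multihead_attention_1d.h"

void run_attention_forward(const __half *input, const __half *input_mask,
                           __half *output, int batch_size, int seq_len) {
  // Hypothetical configuration: layer 0, up to 4096 tokens per batch,
  // max sequence length 512, hidden size 1024, 16 heads, 10% dropout on both
  // attention and hidden output, pre-LayerNorm.
  MultiHeadAttention<__half> attn(/*layer_id=*/0, /*max_batch_tokens=*/4096,
                                  /*_max_seq_len=*/512, /*hidden_size=*/1024,
                                  /*num_heads=*/16, /*attn_dropout_ratio=*/0.1f,
                                  /*hidden_output_dropout_ratio=*/0.1f,
                                  /*pre_or_postLayerNorm=*/true);

  // The current batch shape is set per call, then Forward consumes the
  // device pointers declared in the header.
  attn.set_cur_batch_shape(batch_size, seq_len);
  attn.Forward(input, input_mask, output);
}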
@@ -83,14 +87,17 @@ class MultiHeadAttention {
     }

     _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
-    _soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
-    _ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _soft_out_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _ctx_bufB_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
     _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);

     // buffer size needed by attn bw
-    size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
-                       std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
-                                _max_batch_tokens * _heads / pg_size * _max_seq_len);
+    size_t smem_size =
+        4 * _max_batch_tokens * _hidden_size / pg_size +
+        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
+                 _max_batch_tokens * _heads / pg_size * _max_seq_len);

     if (!_shared_mem_ptr) {
       cuda_free(_shared_mem_ptr);
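The second hunk likewise only re-flows the allocation and buffer-sizing statements; the arithmetic for the backward-pass buffer is unchanged. As a sanity check, here is a standalone sketch that reproduces that sizing formula (element counts, not bytes); the helper name, the example shape, and the reading of pg_size as the tensor-parallel group size are assumptions made for illustration.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Mirrors the smem_size expression from the diff: 4 * tokens * hidden / pg_size
// plus the larger of the two temporary regions used in the attention backward.
std::size_t attn_bw_buffer_elems(std::size_t max_batch_tokens,
                                 std::size_t hidden_size, std::size_t heads,
                                 std::size_t max_seq_len, std::size_t pg_size) {
  return 4 * max_batch_tokens * hidden_size / pg_size +
         std::max(3 * max_batch_tokens * hidden_size / pg_size,
                  max_batch_tokens * heads / pg_size * max_seq_len);
}

int main() {
  // Example shape (hypothetical): 4096 tokens, hidden 1024, 16 heads,
  // sequence length 512, no tensor parallelism (pg_size = 1).
  // 4*4096*1024 + max(3*4096*1024, 4096*16*512) = 16777216 + 33554432.
  std::printf("%zu\n", attn_bw_buffer_elems(4096, 1024, 16, 512, 1));  // 50331648
  return 0;
}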