diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py
index 343c0a9ff..e31d9e5da 100644
--- a/colossalai/kernel/triton/context_attn_unpad.py
+++ b/colossalai/kernel/triton/context_attn_unpad.py
@@ -22,6 +22,7 @@ def _fwd_context_paged_attention_kernel(
     KCache,
     VCache,
     BLOCK_TABLES,  # [num_seqs, max_blocks_per_sequence]
+    batch_size,
     stride_qt,
     stride_qh,
     stride_qd,
@@ -49,6 +50,8 @@ def _fwd_context_paged_attention_kernel(
     BLOCK_N: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
     block_start_m = tl.program_id(2)  # Br, max_input_len // Block_M
     cur_kv_head_idx = cur_head_idx // KV_GROUPS
@@ -217,6 +220,8 @@ def context_attention_unpadded(
     assert block_size in {16, 32, 64, 128}
     BLOCK_M = BLOCK_N = block_size
 
+    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
+    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
 
     _fwd_context_paged_attention_kernel[grid](
@@ -227,6 +232,7 @@ def context_attention_unpadded(
         k_cache,
         v_cache,
         block_tables,
+        num_seqs,
         q.stride(0),
         q.stride(1),
         q.stride(2),
diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
index 25cdea399..0a42a2f13 100644
--- a/colossalai/kernel/triton/flash_decoding.py
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -16,6 +16,7 @@ def _flash_decoding_fwd_kernel(
     mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
+    batch_size,
     stride_qt,
     stride_qh,
     stride_qd,
@@ -39,6 +40,8 @@ def _flash_decoding_fwd_kernel(
     HEAD_DIM: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2)  # for splitting k/v
 
@@ -132,6 +135,7 @@ def _flash_decoding_fwd_reduce_kernel(
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
     O,  # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
     kv_seq_len,
+    batch_size,
     stride_mid_ot,
     stride_mid_oh,
     stride_mid_ob,
@@ -147,6 +151,8 @@ def _flash_decoding_fwd_reduce_kernel(
     HEAD_DIM: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
 
     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
@@ -251,6 +257,8 @@ def flash_decoding_attention(
         else mid_output_lse
     )
 
+    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
+    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
     _flash_decoding_fwd_kernel[grid](
         q,
@@ -260,6 +268,7 @@ def flash_decoding_attention(
         mid_output,
         mid_output_lse,
         kv_seq_len,
+        bsz,
         q.stride(0),
         q.stride(1),
         q.stride(2),
@@ -285,12 +294,14 @@ def flash_decoding_attention(
     output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped
 
-    grid = (bsz, num_heads)
+    grid = (triton.next_power_of_2(bsz), num_heads)
+
    _flash_decoding_fwd_reduce_kernel[grid](
         mid_output,
         mid_output_lse,
         output,
         kv_seq_len,
+        bsz,
         mid_output.stride(0),
         mid_output.stride(1),
         mid_output.stride(2),
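
For context, both files apply the same pattern: the first axis of the launch grid is padded with `triton.next_power_of_2` (which the in-tree NOTE attributes to Triton's caching behavior), the true batch size is passed into the kernel, and program ids at or beyond that size return immediately so the padded programs never touch memory. The sketch below, which is not part of this patch, illustrates the pattern in isolation; the kernel, helper, and tensor names are hypothetical.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows_kernel(src, dst, batch_size, dim, stride_row, BLOCK: tl.constexpr):
    row_idx = tl.program_id(0)
    # Programs that exist only because the grid was padded to a power of 2
    # fall outside the real batch and exit before touching any memory.
    if row_idx >= batch_size:
        return
    offsets = tl.arange(0, BLOCK)
    mask = offsets < dim
    vals = tl.load(src + row_idx * stride_row + offsets, mask=mask)
    tl.store(dst + row_idx * stride_row + offsets, vals, mask=mask)


def copy_rows(src: torch.Tensor) -> torch.Tensor:
    bsz, dim = src.shape
    dst = torch.empty_like(src)
    # Pad the batch axis of the grid to a power of 2 and hand the true batch
    # size to the kernel so it can mask out the extra programs itself.
    grid = (triton.next_power_of_2(bsz),)
    _copy_rows_kernel[grid](src, dst, bsz, dim, src.stride(0), BLOCK=triton.next_power_of_2(dim))
    return dst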