diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py index 7fc9d1553..ead4be8b7 100644 --- a/colossalai/inference/modeling/layers/attention.py +++ b/colossalai/inference/modeling/layers/attention.py @@ -87,7 +87,7 @@ class PagedAttention: Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size] """ bsz = len(seq_lengths) - padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size) + padded_tensor = torch.zeros(bsz, max_seq_len, num_heads, head_size, dtype=tensor.dtype) token_idx = 0 for i, seq_len in enumerate(seq_lengths): diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index 64efa3491..343c0a9ff 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -5,6 +5,8 @@ # # Inspired and modified from Triton Tutorial - Fused Attention # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html +from typing import Optional + import torch import triton import triton.language as tl @@ -190,13 +192,8 @@ def context_attention_unpadded( context_lengths: torch.Tensor, # [num_seqs] block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], block_size: int, + max_seq_len_in_b: Optional[int] = None, ): - # q/k in context stage are supposed to be put into k_cache and v_cache. - # This step can be optimized in future. - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk == Lv assert Lk in {32, 64, 128, 256} @@ -210,7 +207,7 @@ def context_attention_unpadded( num_kv_group = num_heads // num_kv_heads num_seqs, max_blocks_per_seq = block_tables.shape - max_seq_len = context_lengths.max().item() + max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b sm_scale = 1.0 / (Lq**0.5) output = torch.zeros_like(q) @@ -220,7 +217,7 @@ def context_attention_unpadded( assert block_size in {16, 32, 64, 128} BLOCK_M = BLOCK_N = block_size - grid = (num_seqs, num_heads, triton.cdiv(max_seq_len, BLOCK_M)) + grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M)) _fwd_context_paged_attention_kernel[grid]( q, diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index fec12f604..25cdea399 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -215,10 +215,9 @@ def flash_decoding_attention( Returns: Output tensor with shape [bsz, num_heads, q_len, head_dim] """ - if q.dim() == 3: - bsz, num_heads, head_dim = q.shape - else: - raise ValueError(f"The query dim should be 3, but got {q.dim()}.") + q = q.squeeze() if q.dim() == 4 else q + assert q.dim() == 3, f"Incompatible q dim: {q.dim()}" + bsz, num_heads, head_dim = q.shape assert head_dim in {32, 64, 128, 256} assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( diff --git a/tests/test_infer_ops/triton/test_context_attn_unpad.py b/tests/test_infer_ops/triton/test_context_attn_unpad.py index eb71cbed2..4498b8519 100644 --- a/tests/test_infer_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer_ops/triton/test_context_attn_unpad.py @@ -1,7 +1,9 @@ import pytest import torch from packaging import version +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from colossalai.inference.modeling.layers.attention import PagedAttention from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref @@ -89,6 +91,7 @@ def test_context_attention( qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM) qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) + q_unpad = q_unpad.contiguous() k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables( k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device @@ -109,5 +112,103 @@ def test_context_attention( assert torch.equal(v_cache_ref, v_cache_triton) +BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 +configs = [ + triton.testing.Benchmark( + x_names=["KV_LEN"], + x_vals=[2**i for i in range(8, 13)], + # x_vals=[x for x in range(256, 8192, 256)], + line_arg="provider", + line_vals=["torch", "triton"], + line_names=["Torch", "Triton"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}", + args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1}, + ) +] + + +@triton.testing.perf_report(configs) +def bench_kernel( + bsz, + KV_LEN, + provider, + block_size: int, + kv_group_num: int, + same_context_len: bool, +): + num_attn_heads = 16 + max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size) + max_seq_len = block_size * max_num_blocks_per_seq + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + block_size * max_num_blocks_per_seq + dtype = torch.float16 + device = get_current_device() + + if same_context_len: + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + else: + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) + num_tokens = torch.sum(context_lengths).item() + + qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM) + qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) + q_unpad = q_unpad.contiguous() + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM) + k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM) + v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM) + q_padded, k_padded, v_padded = ( + q_padded.to(device=device), + k_padded.to(device=device), + v_padded.to(device=device), + ) + q_padded = q_padded.transpose(1, 2) + k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num) + v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num) + # This benchmark ignores the padding mask. *Only* use the-same-length inputs for benchmarkings + attn_mask = AttentionMaskConverter._make_causal_mask( + (bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0 + ) + attn_mask = attn_mask.to(device=q_padded.device) + fn = lambda: torch_attn_ref( + q_padded, + k_padded, + v_padded, + attn_mask, + bsz, + max_seq_len, + max_seq_len, + num_attn_heads, + num_kv_heads, + HEAD_DIM, + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + if provider == "triton": + k_cache_triton = torch.zeros_like(k_cache_ref) + v_cache_triton = torch.zeros_like(v_cache_ref) + fn = lambda: context_attention_unpadded( + q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + + return ms, min_ms, max_ms + + if __name__ == "__main__": test_context_attention(4, 32, 8, 16, 1, True) + # bench_kernel.run(save_path=".", print_data=True) diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py index e93e072af..063ae2814 100644 --- a/tests/test_infer_ops/triton/test_decoding_attn.py +++ b/tests/test_infer_ops/triton/test_decoding_attn.py @@ -97,7 +97,9 @@ def test_flash_decoding( mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) sm_scale = 1.0 / (HEAD_DIM**0.5) out_triton = flash_decoding_attention( - q, + # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), + # refer to attention forward in modeling. + q.squeeze(2), k_cache, v_cache, kv_seq_lengths, @@ -188,7 +190,9 @@ def bench_kernel( mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) sm_scale = 1.0 / (HEAD_DIM**0.5) fn = lambda: flash_decoding_attention( - q, + # Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1), + # refer to attention forward in modeling. + q.squeeze(2), k_cache, v_cache, kv_lengths,