diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py index f3cfb3860..09e95070a 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/llama.py @@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode from colossalai.inference.modeling.layers.attention import PagedAttention from colossalai.inference.struct import BatchInfo -from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd +from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention from colossalai.logging import get_dist_logger from flash_attn.bert_padding import index_first_axis, pad_input # noqa @@ -209,7 +209,15 @@ def llama_attn_forward( if HAS_TRITON: copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) - attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size) + # TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel + # in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output + # should be revised, as we could see in previous part of `llama_attn_forward` we still have + # redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent. + query_states = query_states.transpose(1, 2) + attn_output = flash_decoding_attention( + query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size + ) + attn_output = attn_output.squeeze(1) else: attn_output = PagedAttention.pad_decoding_forward( query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 763522453..b814b142b 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -9,7 +9,9 @@ except ImportError: # There may exist import error even if we have triton installed. if HAS_TRITON: from .context_attn_unpad import context_attention_unpadded - from .flash_decoding import flash_decoding_fwd + from .flash_decoding import flash_decoding_attention + from .flash_decoding_utils import FDIntermTensors + from .rms_layernorm import rms_layernorm from .gptq_triton import gptq_fused_linear_triton from .kvcache_copy import copy_kv_to_blocked_cache @@ -18,10 +20,11 @@ if HAS_TRITON: __all__ = [ "context_attention_unpadded", - "flash_decoding_fwd", + "flash_decoding_attention", "copy_kv_to_blocked_cache", "softmax", "rms_layernorm", "gptq_fused_linear_triton", "rotary_embedding", + "FDIntermTensors", ] diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index ed1629e96..15f1921ca 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -9,15 +9,16 @@ import triton.language as tl # Triton 2.1.0 @triton.jit def _flash_decoding_fwd_kernel( - Q, # [batch_size, head_num, head_dim] + Q, # [batch_size, head_num, q_len(1), head_dim] KCache, # [num_blocks, num_kv_heads, head_dim, block_size] VCache, # [num_blocks, num_kv_heads, head_dim, block_size] block_tables, # [batch_size, max_blocks_per_sequence] mid_o, # [batch_size, head_num, kv_split_num, head_dim] mid_o_lse, # [batch_size, head_num, kv_split_num] - context_lengths, # [batch_size] + kv_seq_len, # [batch_size] stride_qt, stride_qh, + stride_ql, stride_qd, stride_cacheb, stride_cacheh, @@ -51,7 +52,7 @@ def _flash_decoding_fwd_kernel( tl.static_assert(BLOCK_KV == BLOCK_SIZE) # get the current (kv) sequence length from provided context lengths tensor - cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd q = tl.load(Q + offsets_q) @@ -65,7 +66,6 @@ def _flash_decoding_fwd_kernel( cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) if block_start_kv * BLOCK_KV >= cur_kv_seq_len: - # TODO might want to remove if-else block? return cur_occupied_size = tl.where( @@ -132,7 +132,7 @@ def _flash_decoding_fwd_reduce_kernel( mid_o, # [batch_size, head_num, kv_split_num, head_dim] mid_o_lse, # [batch_size, head_num, kv_split_num] O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim] - context_lengths, + kv_seq_len, stride_mid_ot, stride_mid_oh, stride_mid_ob, @@ -141,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel( stride_o_lseh, stride_o_lseb, stride_ob, + stride_ol, stride_oh, stride_od, BLOCK_KV: tl.constexpr, @@ -149,7 +150,7 @@ def _flash_decoding_fwd_reduce_kernel( cur_seq_idx = tl.program_id(0) cur_head_idx = tl.program_id(1) - cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) offsets_dmodel = tl.arange(0, HEAD_DIM) # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have @@ -181,21 +182,46 @@ def _flash_decoding_fwd_reduce_kernel( # Decoding Stage # Used with blocked KV Cache (PagedAttention) -def flash_decoding_fwd( - q: torch.Tensor, # [bsz(e.g.num_tokens), 1, num_heads, head_dim] - k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] - v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size] - context_lengths: torch.Tensor, # [batch_size] - block_tables: torch.Tensor, # [batch_size, max_blocks_per_sequence] +def flash_decoding_attention( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + kv_seq_len: torch.Tensor, + block_tables: torch.Tensor, block_size: int, - num_kv_group: int = 1, + max_seq_len_in_batch: int = None, + mid_output: torch.Tensor = None, + mid_output_lse: torch.Tensor = None, + sm_scale: int = None, + kv_group_num: int = 1, ): - bsz, _, num_heads, head_dim = q.shape + """ + Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. + + Args: + q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim] + k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] + v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] + kv_seq_len (torch.Tensor): [batch_size] + records the (kv) sequence lengths incorporating past kv sequence lengths. + block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence] + max_seq_len_in_batch (int): Maximum sequence length in the batch. + mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim] + Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`. + mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num] + Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`. + block_size (int): Size of each block in the blocked key/value cache. + num_kv_group (int, optional): Number of key/value groups. Defaults to 1. + + Returns: + Output tensor with shape [bsz, num_heads, q_len, head_dim] + """ + bsz, num_heads, _, head_dim = q.shape assert head_dim in {32, 64, 128, 256} - assert context_lengths.shape[0] == block_tables.shape[0] == bsz, ( + assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, ( f"Got incompatible batch size (number of seqs):\n" - f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, " + f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, " f"batch size {bsz}" ) assert k_cache.size(-1) == v_cache.size(-1) == block_size, ( @@ -203,75 +229,79 @@ def flash_decoding_fwd( f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, " f"v_cache block_size {v_cache.size(-1)}" ) - # NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths. - bsz = context_lengths.size(0) # e.g. the number of seqs - max_seq_len = context_lengths.max().item() - sm_scale = 1.0 / (head_dim**0.5) # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`) assert block_size in {16, 32, 64, 128} BLOCK_KV = block_size - kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV - mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device) - mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale + max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch + # For compatibility (TODO revise modeling in future) + kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV + mid_output = ( + torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device) + if mid_output is None + else mid_output + ) + mid_output_lse = ( + torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + if mid_output_lse is None + else mid_output_lse + ) - if q.dim() == 4: - assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}" - q = q.squeeze(1) - - grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV)) + grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV)) _flash_decoding_fwd_kernel[grid]( q, k_cache, v_cache, block_tables, - mid_o, - mid_o_lse, - context_lengths, + mid_output, + mid_output_lse, + kv_seq_len, q.stride(0), q.stride(1), q.stride(2), + q.stride(3), k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), block_tables.stride(0), block_tables.stride(1), - mid_o.stride(0), - mid_o.stride(1), - mid_o.stride(2), - mid_o.stride(3), - mid_o_lse.stride(0), - mid_o_lse.stride(1), - mid_o_lse.stride(2), + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), sm_scale, - KV_GROUPS=num_kv_group, + KV_GROUPS=kv_group_num, BLOCK_KV=block_size, BLOCK_SIZE=block_size, HEAD_DIM=head_dim, ) - output = torch.zeros_like(q) - output = output.view(-1, output.size(-2), output.size(-1)) + output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped grid = (bsz, num_heads) _flash_decoding_fwd_reduce_kernel[grid]( - mid_o, - mid_o_lse, + mid_output, + mid_output_lse, output, - context_lengths, - mid_o.stride(0), - mid_o.stride(1), - mid_o.stride(2), - mid_o.stride(3), - mid_o_lse.stride(0), - mid_o_lse.stride(1), - mid_o_lse.stride(2), + kv_seq_len, + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), output.stride(0), output.stride(1), output.stride(2), + output.stride(3), BLOCK_KV=block_size, HEAD_DIM=head_dim, ) diff --git a/colossalai/kernel/triton/flash_decoding_utils.py b/colossalai/kernel/triton/flash_decoding_utils.py new file mode 100644 index 000000000..a91524815 --- /dev/null +++ b/colossalai/kernel/triton/flash_decoding_utils.py @@ -0,0 +1,58 @@ +import torch + +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.utils import get_current_device + + +class FDIntermTensors(metaclass=SingletonMeta): + """Singleton class to hold tensors used for storing intermediate values in flash-decoding. + For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv) + """ + + def __init__(self): + self._tensors_initialized = False + + @property + def is_initialized(self): + return self._tensors_initialized + + @property + def mid_output(self): + assert self.is_initialized, "Intermediate tensors not initialized yet" + return self._mid_output + + @property + def mid_output_lse(self): + assert self.is_initialized, "Intermediate tensors not initialized yet" + return self._mid_output_lse + + def initialize( + self, + max_batch_size: int, + num_attn_heads: int, + kv_max_split_num: int, + head_dim: int, + dtype: torch.dtype = torch.float32, + device: torch.device = get_current_device(), + ) -> None: + """Initialize tensors. + + Args: + max_batch_size (int): The maximum batch size over all the model forward. + This could be greater than the batch size in attention forward func when using dynamic batch size. + num_attn_heads (int)): Number of attention heads. + kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm. + **The maximum length/size of blocks splitted on kv should be the kv cache block size.** + head_dim (int): Head dimension. + dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors. + device (torch.device, optional): Device used to initialize intermediate tensors. + """ + assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized." + + self._mid_output = torch.empty( + size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device + ) + self._mid_output_lse = torch.empty( + size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device + ) + self._tensors_initialized = True diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py index 3cd897931..31bd4812a 100644 --- a/tests/test_infer_ops/triton/kernel_utils.py +++ b/tests/test_infer_ops/triton/kernel_utils.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch from torch.nn import functional as F @@ -17,13 +19,22 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim) +def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, kv_seq_len: int, device="cuda"): + padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=device) + for i in range(bsz): + cur_seq_len = kv_lengths[i].item() + assert cur_seq_len <= kv_seq_len + padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf") + return padding_mask + + # Attention calculation adapted from HuggingFace transformers repository # src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350 def torch_attn_ref( - q: torch.Tensor, # [bsz, seq_len, num_heads, head_dim] - k: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim] - v: torch.Tensor, # [bsz, kv_seq_len, num_heads, head_dim] + q: torch.Tensor, # [bsz, num_heads, q_len, head_dim] + k: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim] + v: torch.Tensor, # [bsz, num_heads, kv_seq_len, head_dim] attention_mask: torch.Tensor, # [bsz, 1, seq_len, kv_seq_len] bsz: int, seq_len: int, @@ -31,14 +42,8 @@ def torch_attn_ref( num_heads: int, num_kv_heads: int, head_dim: int, -): +) -> torch.Tensor: assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim - q = q.view(bsz, seq_len, num_heads, head_dim) - k = k.view(bsz, kv_seq_len, num_kv_heads, head_dim) - v = v.view(bsz, kv_seq_len, num_kv_heads, head_dim) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) # repeat kv for GQA and MQA # k/v won't change if kv_group_num is 1 @@ -49,7 +54,6 @@ def torch_attn_ref( qk = torch.matmul(q, k.transpose(2, 3)) attn_scores = qk / (head_dim**0.5) - assert attn_scores.shape == (bsz, num_heads, seq_len, kv_seq_len), "Invalid shape of attention scores" # for left-side padding if attention_mask.size() != (bsz, 1, seq_len, kv_seq_len): @@ -77,7 +81,7 @@ def mock_alloc_block_table_and_kvcache( num_seqs: int, max_num_blocks_per_seq: int, block_size: int, -): +) -> torch.Tensor: """Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache.""" block_id = 0 block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32) @@ -102,12 +106,10 @@ def mock_alloc_block_table_and_kvcache( return block_tables -def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int): - """Allocate 1 token on the block table for each seqs in block tables. - It won't change provided context_lengths - """ - - # consider max_block_id as the last physical block allocated +def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None: + # Allocate 1 token on the block table for each seqs in block tables. + # It won't change provided context_lengths. + # Consider max_block_id as the last physical block allocated # NOTE It assumes all the blocks preceding this block have been allocated max_block_id = torch.max(block_tables).item() # the indices on each block table representing the cache block to be allocated one more token @@ -126,3 +128,36 @@ def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.T if new_block_ids.numel(): new_block_alloc_local_indices = alloc_local_block_indices[require_new_block] block_tables[require_new_block, new_block_alloc_local_indices] = new_block_ids + + +def generate_caches_and_block_tables( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda" +) -> Tuple[torch.Tensor, ...]: + # Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths + # k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) + k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) + # Mock allocation on block tables as well as blocked kv caches + block_tables = mock_alloc_block_table_and_kvcache( + k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size + ) + return k_cache, v_cache, block_tables + + +def convert_kv_unpad_to_padded( + k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int +) -> torch.Tensor: + # Rebuild (batched) k/v with padding to be used by torch attention + # input k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim] + # returns k/v padded [bsz, num_kv_heads, max_seq_len, head_dim] + _, num_kv_heads, head_dim = k_unpad.shape + k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k_unpad.dtype, device=k_unpad.device) + prev_len_sum = 0 + for i, seq_len in enumerate(kv_seq_lengths.tolist()): + # left-side padding + k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len] + prev_len_sum += seq_len + k_torch = k_torch.transpose(1, 2) + return k_torch diff --git a/tests/test_infer_ops/triton/test_context_attn_unpad.py b/tests/test_infer_ops/triton/test_context_attn_unpad.py index 60459a3c2..eb71cbed2 100644 --- a/tests/test_infer_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer_ops/triton/test_context_attn_unpad.py @@ -4,7 +4,7 @@ from packaging import version from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device -from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref +from tests.test_infer_ops.triton.kernel_utils import generate_caches_and_block_tables, torch_attn_ref try: import triton # noqa @@ -16,6 +16,8 @@ except ImportError: TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +HEAD_DIM = 32 + def torch_attn_unpad( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int @@ -34,9 +36,9 @@ def torch_attn_unpad( mask[mask == 0.0] = float("-inf") torch_attn_ref_out = torch_attn_ref( - q[start_idx:end_idx].unsqueeze(0), - k[start_idx:end_idx].unsqueeze(0), - v[start_idx:end_idx].unsqueeze(0), + q[start_idx:end_idx].unsqueeze(0).transpose(1, 2), + k[start_idx:end_idx].unsqueeze(0).transpose(1, 2), + v[start_idx:end_idx].unsqueeze(0).transpose(1, 2), mask, 1, # set bsz as 1 as we're processing sequence one by one seq_len, @@ -74,7 +76,6 @@ def test_context_attention( num_kv_heads = num_attn_heads // kv_group_num assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." - head_dim = 32 max_seq_len = max_num_blocks_per_seq * block_size dtype = torch.float16 device = get_current_device() @@ -85,28 +86,28 @@ def test_context_attention( context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) num_tokens = torch.sum(context_lengths).item() - qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, head_dim) - qkv = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) - q, k, v = torch.split(qkv, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) + qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM) + qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2) - cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) - k_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device) - k_cache_triton = torch.zeros_like(k_cache_torch) - v_cache_torch = torch.zeros(size=cache_shape, dtype=dtype, device=device) - v_cache_triton = torch.zeros_like(v_cache_torch) - - # Mock allocation on block tables - block_tables = mock_alloc_block_table_and_kvcache( - k, v, k_cache_torch, v_cache_torch, context_lengths, bsz, max_num_blocks_per_seq, block_size + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables( + k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) + k_cache_triton = torch.zeros_like(k_cache_ref) + v_cache_triton = torch.zeros_like(v_cache_ref) + out_triton = context_attention_unpadded( - q, k, v, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size + q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size ) - out_torch = torch_attn_unpad(q, k, v, context_lengths, num_attn_heads, num_kv_heads) + out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads) assert out_torch.shape == out_triton.shape - assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) - assert torch.allclose(k_cache_torch, k_cache_triton) - assert torch.allclose(v_cache_torch, v_cache_triton) + assert torch.allclose(out_torch, out_triton, atol=1e-3) + assert torch.equal(k_cache_ref, k_cache_triton) + assert torch.equal(v_cache_ref, v_cache_triton) + + +if __name__ == "__main__": + test_context_attention(4, 32, 8, 16, 1, True) diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py index 58b8fe0cd..e93e072af 100644 --- a/tests/test_infer_ops/triton/test_decoding_attn.py +++ b/tests/test_infer_ops/triton/test_decoding_attn.py @@ -2,9 +2,14 @@ import pytest import torch from packaging import version -from colossalai.kernel.triton import flash_decoding_fwd +from colossalai.kernel.triton import flash_decoding_attention from colossalai.utils import get_current_device -from tests.test_infer_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache, torch_attn_ref +from tests.test_infer_ops.triton.kernel_utils import ( + convert_kv_unpad_to_padded, + generate_caches_and_block_tables, + prepare_padding_mask, + torch_attn_ref, +) try: import triton # noqa @@ -16,23 +21,37 @@ except ImportError: TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +Q_LEN = 1 +HEAD_DIM = 128 -def torch_decoding(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor): - assert context_lengths.dim() == 1, "context_lengths should be a 1D tensor" - assert q.size(1) == 1, "Only used for decoding" - assert k.shape == v.shape - bsz, _, num_heads, head_dim = q.shape - _, kv_seq_len, num_kv_heads, _ = k.shape - assert num_heads % num_kv_heads == 0, "Invalid kv heads and attention heads." - padding_mask = torch.zeros((bsz, 1, 1, kv_seq_len), dtype=torch.float32, device=q.device) - for i in range(bsz): - cur_seq_len = context_lengths[i].item() - assert cur_seq_len <= kv_seq_len - padding_mask[i, :, :, : kv_seq_len - cur_seq_len] = float("-inf") +def prepare_data( + bsz: int, + num_attn_heads: int, + num_kv_heads: int, + head_dim: int, + same_context_len: bool, + q_len: int, + max_kv_seq_len: int, + dtype=torch.float16, + device="cuda", +): + # Use the provided maximum sequence length for each sequence when testing with teh same context length, + # otherwise generate random context lengths. + kv_lengths = ( + torch.tensor([max_kv_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + if same_context_len + else torch.randint(low=1, high=max_kv_seq_len, size=(bsz,), dtype=torch.int32, device=device) + ) + num_tokens = torch.sum(kv_lengths).item() - out = torch_attn_ref(q, k, v, padding_mask, bsz, 1, kv_seq_len, num_heads, num_kv_heads, head_dim) - return out + q_size = (bsz, q_len, num_attn_heads, head_dim) + q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2) + kv_size = (num_tokens, 2 * num_kv_heads, head_dim) + kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2) + + return q, k_unpad, v_unpad, kv_lengths @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") @@ -57,59 +76,135 @@ def test_flash_decoding( num_kv_heads = num_attn_heads // kv_group_num assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." - q_len = 1 - head_dim = 128 max_seq_len = block_size * max_num_blocks_per_seq dtype = torch.float16 device = get_current_device() - if same_context_len: - context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) - else: - context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) - num_tokens = torch.sum(context_lengths).item() - - q_size = (bsz, q_len, num_attn_heads, head_dim) - q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) - kv_size = (num_tokens, 2 * num_kv_heads, head_dim) - kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) - k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2) - - cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size) - k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) - v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) - # Mock allocation on block tables as well as blocked kv caches - block_tables = mock_alloc_block_table_and_kvcache( - k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size + q, k_unpad, v_unpad, kv_seq_lengths = prepare_data( + bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device + ) + k_cache, v_cache, block_tables = generate_caches_and_block_tables( + k_unpad, v_unpad, kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device ) block_tables = block_tables.to(device=device) - - q = q.view(bsz, q_len, num_attn_heads, head_dim) - out_triton = flash_decoding_fwd( + # The maximum sequence length in the batch (if context lengths randomly generated) + max_seq_len_in_b = kv_seq_lengths.max().item() + # The maximum block length splitted on kv should be the kv cache block size + kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size + mid_output = torch.empty( + size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + ) + mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + sm_scale = 1.0 / (HEAD_DIM**0.5) + out_triton = flash_decoding_attention( q, k_cache, v_cache, - context_lengths, + kv_seq_lengths, block_tables, block_size, - kv_group_num, - ) - out_triton = out_triton.unsqueeze(1) # [bsz, 1, num_heads, head_dim] + max_seq_len_in_b, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=kv_group_num, + ) # [bsz, 1, num_heads, head_dim] - # rebuild (batched) kv with padding for torch attention - # q [bsz, 1, num_heads, head_dim] - # k/v [num_tokens, num_kv_heads, head_dim] - max_seq_len = context_lengths.max().item() - k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device) - v_torch = torch.zeros_like(k_torch) - prev_len_sum = 0 - for i, seq_len in enumerate(context_lengths.tolist()): - # mock left-side padding - k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len] - v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len] - prev_len_sum += seq_len - # k/v [bsz, max_seq_len, num_kv_heads, head_dim] - out_torch = torch_decoding(q, k_torch, v_torch, context_lengths) + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, bsz, max_seq_len_in_b) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, bsz, max_seq_len_in_b) + torch_padding_mask = prepare_padding_mask(kv_seq_lengths, bsz, max_seq_len_in_b, q.device) + out_torch = torch_attn_ref( + q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM + ) assert out_torch.shape == out_triton.shape assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4) + + +BATCH = 16 +BLOCK_SIZE = 32 +SAME_LEN = True +WARM_UPS = 10 +REPS = 100 +configs = [ + triton.testing.Benchmark( + x_names=["KV_LEN"], + x_vals=[2**i for i in range(8, 14)], + # x_vals=[x for x in range(256, 8192, 256)], + line_arg="provider", + line_vals=["torch", "triton"], + line_names=["Torch", "Triton"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}", + args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1}, + ) +] + + +@triton.testing.perf_report(configs) +def bench_kernel( + bsz, + KV_LEN, + provider, + block_size: int, + kv_group_num: int, + same_context_len: bool, +): + num_attn_heads = 16 + max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size) + max_seq_len = block_size * max_num_blocks_per_seq + + num_kv_heads = num_attn_heads // kv_group_num + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + block_size * max_num_blocks_per_seq + dtype = torch.float16 + device = get_current_device() + + q, k_unpad, v_unpad, kv_lengths = prepare_data( + bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device + ) + max_seq_len_in_b = kv_lengths.max().item() # for random lengths + + quantiles = [0.5, 0.2, 0.8] + if provider == "torch": + k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b) + v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b) + torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device) + fn = lambda: torch_attn_ref( + q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM + ) + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + if provider == "triton": + k_cache, v_cache, block_tables = generate_caches_and_block_tables( + k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + block_tables = block_tables.to(device=device) + # the maximum block length splitted on kv should be the kv cache block size + kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size + mid_output = torch.empty( + size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device + ) + mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device) + sm_scale = 1.0 / (HEAD_DIM**0.5) + fn = lambda: flash_decoding_attention( + q, + k_cache, + v_cache, + kv_lengths, + block_tables, + block_size, + max_seq_len_in_b, + mid_output, + mid_output_lse, + sm_scale=sm_scale, + kv_group_num=kv_group_num, + ) # [bsz, 1, num_heads, head_dim] + ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles) + + return ms, min_ms, max_ms + + +if __name__ == "__main__": + test_flash_decoding(16, 32, 32, 16, 1, True) + # bench_kernel.run(save_path=".", print_data=True)