diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 44ce381a4..355140bc1 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -282,10 +282,11 @@ class NopadLlamaAttention(LlamaAttention):
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )
 
+        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+
         block_size = k_cache.size(-2)
 
         if is_prompts:
-            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -300,7 +301,7 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], k_cache, block_tables, sequence_lengths)
+            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
                 q=query_states,
diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py
index 8e31b42a8..1aaeb6830 100644
--- a/colossalai/kernel/triton/kvcache_copy.py
+++ b/colossalai/kernel/triton/kvcache_copy.py
@@ -75,6 +75,7 @@ def copy_kv_to_blocked_cache(
     block_size = k_cache.size(-2)
 
     num_warps = 8 if head_dim > 128 else 4
+
     grid = (bsz, num_kv_heads)
     _copy_to_kvcache_seqlen1_kernel[grid](
         k,
diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py
index 7a38c0fc8..9194319d5 100644
--- a/colossalai/kernel/triton/no_pad_rotary_embedding.py
+++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
     out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
     out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]  # total_tokens, head_num, head_dim
 
-    past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1
+    past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
 
     last_block_idx = past_kv_seq_len // block_size
     block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
-    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
+    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
     offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
 
     kv_range0 = (
@@ -274,122 +274,6 @@
     )
 
 
-@triton.jit
-def fused_rotary_embedding_kernel_v2(
-    q,
-    k,
-    cos,
-    sin,
-    kv_cache,
-    BLOCK_TABLES,
-    context_lengths,
-    q_token_stride,
-    q_head_stride,
-    k_token_stride,
-    k_head_stride,
-    head_dim_stride,
-    cos_token_stride,
-    cos_stride,
-    cacheb_stride,
-    cacheh_stride,
-    cachebs_stride,
-    cached_stride,
-    bts_stride,
-    btb_stride,
-    block_size,
-    q_total_tokens,
-    Q_HEAD_NUM: tl.constexpr,
-    K_HEAD_NUM: tl.constexpr,
-    HEAD_DIM: tl.constexpr,
-):
-    block_head_index = tl.program_id(0)
-    if block_head_index >= Q_HEAD_NUM:
-        return
-    block_token_index = tl.program_id(1)
-
-    dim_range0 = tl.arange(0, HEAD_DIM // 2)
-    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
-
-    off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
-    off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
-    off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
-    off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride
-
-    loaded_q0 = tl.load(
-        q + off_q0,
-    )
-    loaded_q1 = tl.load(
-        q + off_q1,
-    )
-
-    loaded_k0 = tl.load(
-        k + off_k0,
-    )
-
-    loaded_k1 = tl.load(
-        k + off_k1,
-    )
-
-    off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
-
-    loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
-    loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
-
-    out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
-    out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
-
-    out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
-    out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos  # total_tokens, head_num, head_dim
-
-    past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
-
-    last_block_idx = past_kv_seq_len // block_size
-    block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
-    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
-    offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
-
-    kv_range0 = (
-        block_ids * cacheb_stride
-        + block_head_index * cacheh_stride
-        + offsets_in_last_block
-        + dim_range0 * cached_stride
-    )
-    kv_range1 = (
-        block_ids * cacheb_stride
-        + block_head_index * cacheh_stride
-        + offsets_in_last_block
-        + dim_range1 * cached_stride
-    )
-
-    tl.store(
-        kv_cache + kv_range0,
-        out_k0,
-    )
-    tl.store(
-        kv_cache + kv_range1,
-        out_k1,
-    )
-
-    # concat
-    tl.store(
-        q + off_q0,
-        out_q0,
-    )
-    tl.store(
-        q + off_q1,
-        out_q1,
-    )
-    tl.store(
-        k + off_k0,
-        out_k0,
-    )
-    tl.store(
-        k + off_k1,
-        out_k1,
-    )
-
-
-@torch.no_grad()
 def rotary_embedding(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -413,13 +297,12 @@
     assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
     BLOCK_TOKENS = 4
+    grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))
 
-    if head_dim >= 1024:
+    if head_dim >= 256:
         num_warps = 32
-    elif head_dim >= 512:
+    elif head_dim >= 128:
         num_warps = 16
-    elif head_dim >= 256:
-        num_warps = 8
     else:
         num_warps = 4
 
@@ -435,10 +318,6 @@
     cos_token_stride = cos.stride(0)
     cos_stride = cos.stride(1)
     if k_cache == None:
-        grid = lambda META: (
-            triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
-            triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
-        )
         rotary_embedding_kernel[grid](
             q,
             k,
@@ -460,8 +339,7 @@
            num_warps=num_warps,
         )
     else:
-        grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
-        fused_rotary_embedding_kernel_v2[grid](
+        fused_rotary_embedding_kernel[grid](
             q,
             k,
             cos,
@@ -487,6 +365,8 @@
             Q_HEAD_NUM=q_head_num,
             K_HEAD_NUM=k_head_num,
             HEAD_DIM=head_dim,
+            BLOCK_HEAD=BLOCK_HEAD,
+            BLOCK_TOKENS=BLOCK_TOKENS,
             num_warps=num_warps,
         )
     return
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
index a8619bce9..2a6e5a5d7 100755
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -1,5 +1,4 @@
 ROOT=$(realpath $(dirname $0))
-echo $ROOT
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
 mode=$1
diff --git a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
index e4f4bb282..6a8dc85f0 100644
--- a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
@@ -3,7 +3,7 @@ import torch
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
 
-from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding
+from colossalai.kernel.triton import rotary_embedding
 from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
 
 try:
@@ -94,8 +94,8 @@ configs = [
         x_names=["num_tokens"],
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
-        line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
-        line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
+        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
@@ -110,16 +110,11 @@ def benchmark_rotary_emb(
     num_tokens: int,
     num_kv_heads: int,
 ):
-    BATCH_SIZE = 4
-    SEQ_LEN = num_tokens // BATCH_SIZE
-    max_num_blocks_per_seq = 8
-    block_size = 64
     warmup = 10
     rep = 100
 
-    head_dim = 256
+    head_dim = 128
    dtype = torch.float16
-
     q_shape = (num_tokens, num_kv_heads, head_dim)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (num_tokens, num_kv_heads, head_dim)
@@ -127,26 +122,11 @@
     cos_shape = (num_tokens, head_dim // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
-    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
-    v = torch.randn_like(k)
-    v_cache = torch.zeros_like(k_cache)
-    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
-    block_tables = mock_alloc_block_table_and_kvcache_v2(
-        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
-    )
-    new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
-    new_q = torch.randn_like(new_k)
-    kv_seq_lengths = past_kv_seq_lengths + 1
-    block_tables = block_tables.to(device="cuda")
 
-    if provider == "no_fused_rotary_emb_func":
-        fn = lambda: [
-            rotary_embedding(new_q, new_k, cos, sin),
-            copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables),
-        ]
-    elif provider == "fused_triton_rotary_emb_func":
-        fn = lambda: rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
+    if provider == "torch_rotary_emb_func":
+        fn = lambda: torch_rotary_emb(q, cos, sin)
+    elif provider == "triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(q, k, cos, sin)
     else:
         raise ValueError("Undefined provider")
 
@@ -155,5 +135,5 @@
 
 
 if __name__ == "__main__":
-    # test_rotary_emb(4, 64, 32, 64, torch.float32)
-    benchmark_rotary_emb.run(save_path=".", print_data=True)
+    test_rotary_emb(4, 64, 32, 64, torch.float32)
+    # benchmark_rotary_emb.run(save_path=".",print_data=True)
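Taken together, the model-level hunks above reorder the Llama attention path: rotary embedding is now applied to the query/key states once, before the prompt/decode branch, and during decoding the new key token is written into the blocked KV cache with copy_kv_to_blocked_cache just like the value token, instead of being fused into the rotary kernel. Below is a minimal pure-PyTorch sketch of that decode-step call order and of the blocked-cache indexing the Triton kernels use (block id = position // block_size, in-block offset = position % block_size, with block_size taken from cache.size(-2)). The helper names ref_rotary_embedding and ref_copy_to_blocked_cache, and every shape in the sketch, are illustrative assumptions rather than code from the repository.

# Reference sketch (assumed shapes/names, not repository code): mirrors the
# patched decode path -- rotate q/k first, then copy the new k and v tokens
# into the blocked (paged) KV cache.
import torch


def ref_rotary_embedding(q, k, cos, sin):
    # q, k: [num_tokens, num_heads, head_dim]; cos, sin: [num_tokens, head_dim // 2]
    # Returns new tensors; the Triton kernel rotates in place.
    def rotate(x):
        x0, x1 = x.chunk(2, dim=-1)
        c, s = cos[:, None, :], sin[:, None, :]
        return torch.cat((x0 * c - x1 * s, x0 * s + x1 * c), dim=-1)

    return rotate(q), rotate(k)


def ref_copy_to_blocked_cache(x, cache, kv_lengths, block_tables):
    # x: [bsz, num_kv_heads, head_dim] -- one new token per sequence
    # cache: [num_blocks, num_kv_heads, block_size, head_dim]
    block_size = cache.size(-2)
    for seq_id in range(x.size(0)):
        pos = kv_lengths[seq_id].item() - 1          # slot of the newest token
        block_id = block_tables[seq_id, pos // block_size]
        cache[block_id, :, pos % block_size, :] = x[seq_id]


if __name__ == "__main__":
    bsz, num_heads, head_dim, block_size, max_blocks = 2, 4, 64, 16, 8
    q = torch.randn(bsz, num_heads, head_dim)
    k = torch.randn(bsz, num_heads, head_dim)
    v = torch.randn(bsz, num_heads, head_dim)
    cos = torch.randn(bsz, head_dim // 2)
    sin = torch.randn(bsz, head_dim // 2)
    k_cache = torch.zeros(bsz * max_blocks, num_heads, block_size, head_dim)
    v_cache = torch.zeros_like(k_cache)
    kv_lengths = torch.tensor([5, 20])               # sequence lengths including the new token
    block_tables = torch.arange(bsz * max_blocks).view(bsz, max_blocks)

    # Same order as the patched NopadLlamaAttention decode branch:
    q, k = ref_rotary_embedding(q, k, cos, sin)
    ref_copy_to_blocked_cache(k, k_cache, kv_lengths, block_tables)
    ref_copy_to_blocked_cache(v, v_cache, kv_lengths, block_tables)

The prefill branch is unchanged apart from the hoisted rotary call, so context_attention_unpadded still receives the same rotated query/key states as before.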