From 12f10d5b0b49a180bc162e166337942e0bbfb96b Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Tue, 23 Apr 2024 13:44:49 +0800
Subject: [PATCH] [Fix/Inference]Fix CUDA Rotary Embedding GQA (#5623)

* fix rotary embedding GQA

* change test_rotary_embdding_unpad.py KH
---
 .../csrc/cuda/fused_rotary_emb_and_cache_kernel.cu |  4 ++--
 .../test_ops/cuda/test_rotary_embdding_unpad.py    | 13 +++++++------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
index 4f589597f..29715ca22 100644
--- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
+++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
@@ -115,7 +115,7 @@ __device__ void apply_k_rotary_emb_compute(
         (head_offset % shard_block_size) / VecSize;
     const int64_t addr_offset =
         token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
-    const int64_t target_id = block_id * head_num * head_dim * block_size +
+    const int64_t target_id = block_id * kv_head_num * head_dim * block_size +
                               (i / half_head_dim) * block_size * head_dim +
                               block_offset * head_dim + head_offset;
 
@@ -137,7 +137,7 @@ __device__ void apply_k_rotary_emb_compute(
 
   // apply value memcopy
   apply_kv_memcopy(
-      value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
+      value, value_cache, value_stride, token_id, block_id, kv_head_num * head_dim,
       block_size, block_offset, head_dim, half_head_dim);
 }
 
diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
index 9e0a8b0db..6f5d0ac84 100644
--- a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
@@ -21,9 +21,10 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("BATCH_SIZE", [4])
 @pytest.mark.parametrize("SEQ_LEN", [64])
 @pytest.mark.parametrize("H", [32])
+@pytest.mark.parametrize("K_H", [16, 32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
@@ -43,12 +44,12 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
     q_shape = (TOTAL_TOKENS, H, D)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
-    k_shape = (TOTAL_TOKENS, H, D)
+    k_shape = (TOTAL_TOKENS, K_H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
+    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
     v_cache = torch.zeros_like(k_cache)
@@ -56,8 +57,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     block_tables = mock_alloc_block_table_and_kvcache_v2(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
     )
-    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
-    new_q = torch.randn_like(new_k)
+    new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_v = torch.randn_like(new_k)
     kv_seq_lengths = past_kv_seq_lengths + 1
 
@@ -123,4 +124,4 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
 
 
 if __name__ == "__main__":
-    test_rotary_emb(16, 64, 4, 128, torch.float16)
+    test_rotary_emb(16, 64, 32, 16, 128, torch.float16)
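Reviewer note (not part of the patch): the kernel change swaps head_num (number of query heads) for kv_head_num (number of key/value heads) in the key-cache offset, which only differ under grouped-query attention. The sketch below is a minimal PyTorch illustration of that offset math, assuming a contiguous k_cache of shape (num_blocks, kv_head_num, block_size, head_dim) as in the test above; write_k_to_cache and its arguments are hypothetical names used only here, not helpers from the ColossalAI code.

    import torch

    def write_k_to_cache(k_cache, new_k, block_id, block_offset):
        # k_cache: (num_blocks, kv_head_num, block_size, head_dim), contiguous
        # new_k:   (kv_head_num, head_dim) -- one token's rotated key, one row per KV head
        num_blocks, kv_head_num, block_size, head_dim = k_cache.shape
        flat = k_cache.view(-1)  # flat alias of the same storage, like the raw pointer in the kernel
        for kv_head in range(kv_head_num):
            # Mirrors the fixed target_id: block_id * kv_head_num * head_dim * block_size
            # + kv_head * block_size * head_dim + block_offset * head_dim (+ head_offset)
            base = (block_id * kv_head_num * head_dim * block_size
                    + kv_head * block_size * head_dim
                    + block_offset * head_dim)
            flat[base:base + head_dim] = new_k[kv_head]

If the per-block stride used head_num instead, it would overshoot the real allocation whenever head_num > kv_head_num, so writes for block_id > 0 would land outside that block's slots; that is the corruption the kernel fix removes. The updated test exercises both K_H == H (MHA) and K_H < H (GQA) and can be run on a CUDA machine with pytest tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py.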