[Fix/Inference] Fix CUDA Rotary Embedding GQA (#5623)

* fix rotary embedding GQA
* change test_rotary_embdding_unpad.py KH
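For context: the updated test exercises grouped-query attention (GQA), where the query tensor keeps H heads while the key tensor and the paged KV cache only carry K_H heads, with H divisible by K_H. Below is a minimal shape sketch of that relationship, assuming the half-split rotary convention used by Hugging Face LLaMA; apply_rotary and the head-expansion step are illustrative only and are not the CUDA kernel's API:

import torch

H, K_H, D, TOKENS = 32, 16, 64, 256      # query heads, key/value heads, head dim, token count
assert H % K_H == 0                       # GQA: each KV head serves H // K_H query heads

q = torch.randn(TOKENS, H, D)             # queries keep H heads
k = torch.randn(TOKENS, K_H, D)           # keys (and the KV cache) only need K_H heads
cos = torch.randn(TOKENS, D // 2)         # per-token rotary tables, shared by q and k
sin = torch.randn(TOKENS, D // 2)

def apply_rotary(x, cos, sin):
    # Rotary embedding rotates feature pairs independently per head,
    # so the same cos/sin tables work for any head count.
    x1, x2 = x.chunk(2, dim=-1)
    c, s = cos.unsqueeze(1), sin.unsqueeze(1)   # broadcast over the head dimension
    return torch.cat((x1 * c - x2 * s, x2 * c + x1 * s), dim=-1)

q_rot = apply_rotary(q, cos, sin)          # (TOKENS, H, D)
k_rot = apply_rotary(k, cos, sin)          # (TOKENS, K_H, D)
k_for_attn = k_rot.repeat_interleave(H // K_H, dim=1)   # expanded to H heads only at attention time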
@@ -21,9 +21,10 @@ def numpy_allclose(x, y, rtol, atol):
 @pytest.mark.parametrize("BATCH_SIZE", [4])
 @pytest.mark.parametrize("SEQ_LEN", [64])
 @pytest.mark.parametrize("H", [32])
+@pytest.mark.parametrize("K_H", [16, 32])
 @pytest.mark.parametrize("D", [64])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
-def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
+def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
     # our crafted op equals to Transformers
@@ -43,12 +44,12 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
     q_shape = (TOTAL_TOKENS, H, D)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
-    k_shape = (TOTAL_TOKENS, H, D)
+    k_shape = (TOTAL_TOKENS, K_H, D)
     k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
     cos_shape = (TOTAL_TOKENS, D // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
+    cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
     v = torch.randn_like(k)
     v_cache = torch.zeros_like(k_cache)
@@ -56,8 +57,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     block_tables = mock_alloc_block_table_and_kvcache_v2(
         k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
     )
-    new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
-    new_q = torch.randn_like(new_k)
+    new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
+    new_q = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
     new_v = torch.randn_like(new_k)
 
     kv_seq_lengths = past_kv_seq_lengths + 1
@@ -123,4 +124,4 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
 
 
 if __name__ == "__main__":
-    test_rotary_emb(16, 64, 4, 128, torch.float16)
+    test_rotary_emb(16, 64, 32, 16, 128, torch.float16)
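The new_k/new_v tensors and the K_H-sized cache in the hunks above feed a fused rotary-plus-cache-copy path during decoding. A pure-PyTorch sketch of the cache-copy half is shown here only to illustrate how block_tables and kv_seq_lengths address the paged cache with its new K_H head dimension; the helper name and the Python loop are hypothetical and are not the kernel's interface:

import torch

def reference_cache_copy(new_k, new_v, k_cache, v_cache, block_tables, kv_seq_lengths, block_size):
    # new_k/new_v: (BATCH_SIZE, K_H, D); k_cache/v_cache: (num_blocks, K_H, block_size, D)
    # block_tables: (BATCH_SIZE, max_blocks_per_sequence) mapping logical blocks to physical ones
    for seq_idx in range(new_k.size(0)):
        pos = int(kv_seq_lengths[seq_idx]) - 1                  # slot of the newly decoded token
        block_id = int(block_tables[seq_idx, pos // block_size]) # physical block holding that slot
        offset = pos % block_size
        k_cache[block_id, :, offset, :] = new_k[seq_idx]
        v_cache[block_id, :, offset, :] = new_v[seq_idx]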