mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattention
This commit is contained in:
@@ -4,8 +4,10 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
@@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol):
|
||||
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
|
||||
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
def test_flash_decoding_attention(
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
@@ -73,6 +76,11 @@ def test_flash_decoding_attention(
|
||||
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
|
||||
device = get_current_device()
|
||||
|
||||
if use_alibi_slopes:
|
||||
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
|
||||
else:
|
||||
alibi_slopes = None
|
||||
|
||||
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
|
||||
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
|
||||
)
|
||||
@@ -91,6 +99,15 @@ def test_flash_decoding_attention(
|
||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
|
||||
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
|
||||
|
||||
if use_alibi_slopes:
|
||||
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
|
||||
torch_padding_mask = torch_padding_mask + alibi_mask
|
||||
|
||||
if len(torch_padding_mask.size()) == 4:
|
||||
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
|
||||
else:
|
||||
torch_padding_mask = torch_padding_mask[:, -1:, :]
|
||||
|
||||
mid_output = torch.empty(
|
||||
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
|
||||
)
|
||||
@@ -146,8 +163,14 @@ def test_flash_decoding_attention(
|
||||
max_seq_len_across_batch,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
alibi_slopes,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
# The alibi may introduce relatively large errors
|
||||
if use_alibi_slopes:
|
||||
rtol = 1e0
|
||||
|
||||
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@@ -168,8 +191,9 @@ except ImportError:
|
||||
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
|
||||
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
|
||||
def test_vllm_flash_decoding_attention(
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
|
||||
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
|
||||
):
|
||||
torch.manual_seed(123)
|
||||
torch.cuda.empty_cache()
|
||||
@@ -199,6 +223,18 @@ def test_vllm_flash_decoding_attention(
|
||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
|
||||
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
|
||||
|
||||
if use_alibi_slopes:
|
||||
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
|
||||
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
|
||||
torch_padding_mask = torch_padding_mask + alibi_mask
|
||||
|
||||
if len(torch_padding_mask.size()) == 4:
|
||||
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
|
||||
else:
|
||||
torch_padding_mask = torch_padding_mask[:, -1:, :]
|
||||
else:
|
||||
alibi_slopes = None
|
||||
|
||||
if dtype == torch.float16:
|
||||
rtol = 1e-3
|
||||
atol = 1e-3
|
||||
@@ -236,8 +272,6 @@ def test_vllm_flash_decoding_attention(
|
||||
HEAD_SIZE,
|
||||
)
|
||||
|
||||
alibi_slopes = None
|
||||
|
||||
vllm_ops.paged_attention_v1(
|
||||
output,
|
||||
q.squeeze(2),
|
||||
@@ -253,6 +287,11 @@ def test_vllm_flash_decoding_attention(
|
||||
"auto",
|
||||
kv_scale,
|
||||
)
|
||||
|
||||
# The alibi may introduce relatively large errors
|
||||
if use_alibi_slopes:
|
||||
rtol = 1e0
|
||||
|
||||
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@@ -277,5 +316,5 @@ if __name__ == "__main__":
|
||||
dtype,
|
||||
) in test_combinations:
|
||||
test_flash_decoding_attention(
|
||||
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
|
||||
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
|
||||
)
|
||||
|
@@ -4,12 +4,40 @@ import torch.nn.functional as F
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.utils import get_current_device
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2
|
||||
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v3, mock_alloc_single_token
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
HEAD_DIM = 4
|
||||
HEAD_DIM = 72
|
||||
|
||||
|
||||
def prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
context_lengths,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
):
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
max_seq_len_in_batch = context_lengths.max()
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
|
||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v3(
|
||||
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
|
||||
block_tables = block_tables.to(device=device)
|
||||
k_cache = torch.zeros_like(k_cache_ref)
|
||||
v_cache = torch.zeros_like(v_cache_ref)
|
||||
|
||||
return key, value, k_cache, v_cache, cu_seqlens, block_tables, max_seq_len_in_batch, k_cache_ref, v_cache_ref
|
||||
|
||||
|
||||
def run_decode_copy_kv_to_caches(
|
||||
@@ -24,32 +52,41 @@ def run_decode_copy_kv_to_caches(
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
n = 1
|
||||
|
||||
max_seq_len = block_size * max_num_blocks_per_seq
|
||||
dtype = torch.float32
|
||||
device = get_current_device()
|
||||
|
||||
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||
bsz,
|
||||
num_kv_heads,
|
||||
HEAD_DIM,
|
||||
block_size,
|
||||
max_num_blocks_per_seq,
|
||||
same_context_len,
|
||||
max_seq_len,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
assert max_seq_len > n, "max_seq_len must be greater than n"
|
||||
|
||||
past_kv_seq_lengths = (
|
||||
torch.tensor([max_seq_len - n for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||
if same_context_len
|
||||
else torch.randint(low=1, high=max_seq_len - n, size=(bsz,), dtype=torch.int32, device=device)
|
||||
)
|
||||
|
||||
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
|
||||
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||
key, value, k_cache, v_cache, _, block_tables, _, _, _ = prepare_data(
|
||||
bsz, num_kv_heads, block_size, max_num_blocks_per_seq, past_kv_seq_lengths, device, dtype
|
||||
)
|
||||
|
||||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
new_k = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||
new_v = torch.randn((bsz, num_kv_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||
|
||||
# mock allocating blocks for the new k/v and update block tables
|
||||
for _ in range(n):
|
||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||
past_kv_seq_lengths += 1
|
||||
|
||||
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, past_kv_seq_lengths, block_tables)
|
||||
|
||||
past_kv_seq_len = past_kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :]
|
||||
k_source = new_k.squeeze()
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||
k_target = k_target.reshape(v_target.shape)
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
assert k_target.shape == k_source.shape
|
||||
@@ -77,22 +114,17 @@ def run_context_copy_kv_to_cache(
|
||||
else:
|
||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
max_seq_len_in_batch = context_lengths.max()
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
|
||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||
key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||
)
|
||||
|
||||
block_tables = block_tables.to(device=device)
|
||||
k_cache = torch.zeros_like(k_cache_ref)
|
||||
v_cache = torch.zeros_like(v_cache_ref)
|
||||
(
|
||||
key,
|
||||
value,
|
||||
k_cache,
|
||||
v_cache,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_seq_len_in_batch,
|
||||
k_cache_ref,
|
||||
v_cache_ref,
|
||||
) = prepare_data(bsz, num_kv_heads, block_size, max_num_blocks_per_seq, context_lengths, device, dtype)
|
||||
|
||||
inference_ops.context_kv_cache_memcpy(
|
||||
key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
|
||||
|
@@ -7,7 +7,7 @@ from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v3
|
||||
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
|
||||
|
||||
|
||||
@@ -49,12 +49,14 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
|
||||
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
k_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, D // x, block_size, x)
|
||||
v_cache_shape = (BATCH_SIZE * max_blocks_per_sequence, K_H, block_size, D)
|
||||
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device="cuda")
|
||||
v = torch.randn_like(k)
|
||||
v_cache = torch.zeros_like(k_cache)
|
||||
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device="cuda")
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v3(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
|
||||
)
|
||||
new_k = torch.randn((BATCH_SIZE, K_H, D), dtype=dtype, device="cuda")
|
||||
@@ -97,9 +99,10 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
|
||||
past_kv_seq_len = kv_seq_lengths - 1
|
||||
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||
offsets_in_block = past_kv_seq_len % block_size
|
||||
k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||
k_target = k_cache[target_block_ids, :, :, offsets_in_block, :].squeeze()
|
||||
k_source = new_k_copy.squeeze()
|
||||
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||
k_target = k_target.reshape(v_target.shape)
|
||||
v_source = new_v.squeeze()
|
||||
|
||||
numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
|
||||
|
Reference in New Issue
Block a user