[Inference/Kernel] Add Paged Decoding kernel, sequence split within the same thread block (#5531)

* feat flash decoding for paged attention

* refactor flashdecodingattention

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Steve Luo
Committed: 2024-04-18 16:45:07 +08:00 (via GitHub)
Parent: 56b222eff8
Commit: be396ad6cc
15 changed files with 1765 additions and 211 deletions

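The "sequence split within the same thread block" in the title is the flash-decoding scheme: each sequence's KV cache is processed in splits, every split produces a partial attention output plus the log-sum-exp (LSE) of its scores, and the partials are merged at the end. A rough PyTorch sketch of that merge (illustrative only; combine_splits is a hypothetical helper, not the kernel's code):

import torch

def combine_splits(mid_output: torch.Tensor, mid_output_lse: torch.Tensor) -> torch.Tensor:
    # mid_output:     [bsz, num_heads, num_splits, head_size] partial outputs
    # mid_output_lse: [bsz, num_heads, num_splits] per-split log-sum-exp
    global_lse = torch.logsumexp(mid_output_lse, dim=-1, keepdim=True)
    weights = torch.exp(mid_output_lse - global_lse)        # softmax weights over splits
    return (weights.unsqueeze(-1) * mid_output).sum(dim=2)  # [bsz, num_heads, head_size]

The mid_output and mid_output_lse buffers allocated in the tests below are exactly this per-split workspace.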

@@ -0,0 +1,274 @@
from itertools import product
import numpy as np
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
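# Load ColossalAI's compiled CUDA inference kernels; flash_decoding_attention used below comes from this extension.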
inference_ops = InferenceOpsLoader().load()

from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_vllm,
torch_attn_ref,
)

q_len = 1


def prepare_data(
BATCH_SIZE: int,
HEAD_SIZE: int,
NUM_ATTN_HEADS: int,
NUM_KV_HEADS: int,
MAX_SEQ_LEN: int,
dtype=torch.float16,
device="cuda",
):
    # Generate a random KV length in [1, MAX_SEQ_LEN) for each sequence in the batch.
    # Returns:
    #   q:               [BATCH_SIZE, NUM_ATTN_HEADS, q_len, HEAD_SIZE]
    #   k_unpad/v_unpad: [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
num_tokens = torch.sum(kv_lengths).item()
q_size = (BATCH_SIZE, q_len, NUM_ATTN_HEADS, HEAD_SIZE)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)
    return q, k_unpad, v_unpad, kv_lengths


def numpy_allclose(x, y, rtol, atol):
x_numpy = x.detach().cpu().numpy()
y_numpy = y.detach().cpu().numpy()
    np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)


@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert NUM_ATTN_HEADS % KV_GROUP_NUM == 0 and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
    max_seq_len_across_batch = kv_seq_lengths.max().item()
    # Upper bound on the number of KV splits: ceil(max_seq_len / BLOCK_SIZE).
    kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
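    # Workspace for the split-KV kernel: each split writes its partial output and
    # the log-sum-exp of its attention scores here, to be reduced into `output`.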
mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
inference_ops.flash_decoding_attention(
output,
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
sm_scale,
)
    numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_vllm_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
try:
from vllm._C import ops as vllm_ops
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
    NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
    assert NUM_ATTN_HEADS % KV_GROUP_NUM == 0 and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
else:
rtol = 1e-5
atol = 1e-7
out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
alibi_slopes = None
vllm_ops.paged_attention_v1(
output,
q.squeeze(2),
k_cache,
v_cache,
NUM_KV_HEADS,
sm_scale,
block_tables,
kv_seq_lengths,
BLOCK_SIZE,
max_seq_len_across_batch,
alibi_slopes,
"auto",
)
    numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


if __name__ == "__main__":
BATCH_SIZE = [1, 4, 7, 32]
BLOCK_SIZE = [8, 16, 32]
MAX_NUM_BLOCKS_PER_SEQ = [1, 8, 32]
HEAD_SIZE = [64, 128]
NUM_ATTN_HEADS = [16]
KV_GROUP_NUM = [1, 2, 16]
DTYPE = [torch.float16, torch.float32]
test_combinations = list(
product(BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, DTYPE)
)
for (
batch_size,
block_size,
max_num_blocks_per_seq,
head_size,
num_attn_heads,
kv_group_num,
dtype,
) in test_combinations:
test_flash_decoding_attention(
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
)
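
Note: the __main__ driver sweeps every parameter combination but only runs the ColossalAI kernel test; test_vllm_flash_decoding_attention is left to pytest, since it requires vllm to be installed.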


@@ -150,6 +150,51 @@ def mock_alloc_block_table_and_kvcache_v2(
    return block_tables


def mock_alloc_block_table_and_kvcache_vllm(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
_, num_kv_heads, head_dim = k.shape
    # vLLM packs the head dim in 16-byte groups: x elements per group (8 for fp16, 4 for fp32).
    x = 16 // torch.tensor([], dtype=k.dtype).element_size()
for i, seq_len in enumerate(context_lengths.tolist()):
        right_bound = (seq_len + block_size - 1) // block_size  # ceil-div: number of blocks this sequence occupies
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
        # Manually fill the kv caches block by block, copying from k and v.
        for block_idx in range(right_bound):
            if block_idx == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
            # [allocated_locs, num_kv_heads, head_dim] -> [num_kv_heads, head_dim // x, allocated_locs, x]
k_block = (
k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
.reshape(allocated_locs, num_kv_heads, head_dim // x, x)
.permute(1, 2, 0, 3)
)
            # [allocated_locs, num_kv_heads, head_dim] -> [num_kv_heads, head_dim, allocated_locs]
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 2, 0)
k_cache[block_id, :, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :, :allocated_locs] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
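
For reference, the x factor above reflects vLLM's key-cache layout, which packs the head dimension into 16-byte groups (x = 8 for fp16, x = 4 for fp32) so each group maps to one vectorized load. A minimal, self-contained round-trip check of that packing (shapes chosen arbitrarily for illustration):

import torch

block_size, num_kv_heads, head_dim = 16, 4, 64
dtype = torch.float16
x = 16 // torch.tensor([], dtype=dtype).element_size()  # 8 for fp16

k = torch.randn(block_size, num_kv_heads, head_dim, dtype=dtype)
# Pack: [block_size, heads, head_dim] -> [heads, head_dim // x, block_size, x]
k_block = k.reshape(block_size, num_kv_heads, head_dim // x, x).permute(1, 2, 0, 3)
# Unpacking reverses the permutation and restores the original tensor exactly.
k_restored = k_block.permute(2, 0, 1, 3).reshape(block_size, num_kv_heads, head_dim)
assert torch.equal(k, k_restored)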
def mock_alloc_single_token(block_tables: torch.Tensor, context_lengths: torch.Tensor, block_size: int) -> None:
    # Allocate one more token on the block table for each sequence.
    # The provided context_lengths are left unchanged.
@@ -206,6 +251,26 @@ def generate_caches_and_block_tables_v2(
    return k_cache, v_cache, block_tables


def generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
    # Mock generation of blocked k/v caches and block tables from the provided unpadded kv and sequence lengths.
    # k_unpad/v_unpad: [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
    # 16-byte packing along the head dim: x elements per group (8 for fp16, 4 for fp32).
    x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_vllm(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
    return k_cache, v_cache, block_tables


def convert_kv_unpad_to_padded(
k_unpad: torch.Tensor, kv_seq_lengths: torch.Tensor, bsz: int, max_seq_len: int
) -> torch.Tensor: