[Inference/Kernel] Add Paged Decoding kernel, sequence split within the same thread block (#5531)

* feat flash decoding for paged attention

* refactor flashdecodingattention

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Steve Luo
2024-04-18 16:45:07 +08:00
committed by GitHub
parent 56b222eff8
commit be396ad6cc
15 changed files with 1765 additions and 211 deletions

View File

@@ -4,8 +4,8 @@ from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
prepare_padding_mask,
torch_attn_ref,
)
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
@@ -67,9 +67,18 @@ def bench_kernel(
if provider == "torch":
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
torch_padding_mask = create_attention_mask(kv_lengths, bsz, Q_LEN, max_seq_len_in_b, q.device)
fn = lambda: torch_attn_ref(
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
q,
k_torch,
v_torch,
torch_padding_mask,
bsz,
Q_LEN,
max_seq_len_in_b,
num_attn_heads,
num_kv_heads,
HEAD_DIM,
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
if provider == "triton":

View File

@@ -0,0 +1,173 @@
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.kernel_utils import (
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_vllm,
)
try:
import triton # noqa
except ImportError:
print("please install triton from https://github.com/openai/triton")
inference_ops = InferenceOpsLoader().load()
# Triton benchmark plot attributions
configs = [
triton.testing.Benchmark(
x_names=["MAX_NUM_BLOCKS_PER_SEQ"],
x_vals=[2**i for i in range(3, 8)],
line_arg="provider",
line_vals=[
"vllm_paged_decoding_attention",
"triton_flash_decoding_attention",
"cuda_flash_decoding_attention",
],
line_names=[
"vllm_paged_decoding_attention",
"triton_flash_decoding_attention",
"cuda_flash_decoding_attention",
],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
ylabel="ms",
plot_name=f"FlashDecodingAttention benchmarking results",
args={"BATCH_SIZE": 16, "BLOCK_SIZE": 32, "HEAD_SIZE": 128, "KV_GROUP_NUM": 2},
)
]
def prepare_data(
BATCH_SIZE: int,
HEAD_SIZE: int,
NUM_ATTN_HEADS: int,
NUM_KV_HEADS: int,
MAX_SEQ_LEN: int,
dtype=torch.float16,
device="cuda",
):
# Use the provided maximum sequence length for each sequence when testing with teh same context length,
# otherwise generate random context lengths.
# returns
# q [BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE]
# k_unpad/v_unpad [num_tokens, NUM_KV_HEADS, HEAD_SIZE]
kv_lengths = torch.randint(low=1, high=MAX_SEQ_LEN, size=(BATCH_SIZE,), dtype=torch.int32, device=device)
num_tokens = torch.sum(kv_lengths).item()
q_size = (BATCH_SIZE, 1, NUM_ATTN_HEADS, HEAD_SIZE)
q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
kv_size = (num_tokens, 2 * NUM_KV_HEADS, HEAD_SIZE)
kv_unpad = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
k_unpad, v_unpad = torch.split(kv_unpad, [NUM_KV_HEADS, NUM_KV_HEADS], dim=-2)
return q, k_unpad, v_unpad, kv_lengths
@triton.testing.perf_report(configs)
def benchmark_flash_decoding_attention(
provider: str,
BATCH_SIZE: int,
BLOCK_SIZE: int,
MAX_NUM_BLOCKS_PER_SEQ: int,
HEAD_SIZE: int,
KV_GROUP_NUM: int,
):
try:
from vllm._C import ops as vllm_ops
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
warmup = 10
rep = 1000
dtype = torch.float16
NUM_ATTN_HEADS = 16
NUM_KV_HEADS = NUM_ATTN_HEADS // KV_GROUP_NUM
assert isinstance(NUM_KV_HEADS, int) and NUM_KV_HEADS > 0, "Invalid number of kv heads."
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
vllm_k_cache, vllm_v_cache, _ = generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
if provider == "vllm_paged_decoding_attention":
alibi_slopes = None
fn = lambda: vllm_ops.paged_attention_v1(
output,
q.squeeze(2),
vllm_k_cache,
vllm_v_cache,
NUM_KV_HEADS,
sm_scale,
block_tables,
kv_seq_lengths,
BLOCK_SIZE,
max_seq_len_across_batch,
alibi_slopes,
"auto",
)
elif provider == "triton_flash_decoding_attention":
fn = lambda: flash_decoding_attention(
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
output,
mid_output,
mid_output_lse,
sm_scale=sm_scale,
kv_group_num=KV_GROUP_NUM,
) # [bsz, 1, num_heads, head_dim]
elif provider == "cuda_flash_decoding_attention":
fn = lambda: inference_ops.flash_decoding_attention(
output,
q.squeeze(2),
k_cache,
v_cache,
kv_seq_lengths,
block_tables,
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
sm_scale,
)
else:
raise ValueError("Undefined provider.")
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
if __name__ == "__main__":
benchmark_flash_decoding_attention.run(save_path=".", print_data=True)