add paged-attention v2: support seq length split across thread blocks (#5707)

Steve Luo
2024-05-14 12:46:54 +08:00
committed by GitHub
parent 18d67d0e8e
commit 7806842f2d
8 changed files with 704 additions and 249 deletions
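In paged-attention v2 (FlashDecoding-style decoding), the KV cache of a long sequence is split into chunks that are processed by separate thread blocks, and the partial results are merged afterwards. This commit switches the benchmark's merge statistics from a single log-sum-exp buffer (`mid_output_lse`) to the `exp_sums` / `max_logits` pair, matching the buffer layout vLLM's paged attention v2 kernels use. Assuming `mid_output_lse` held a per-split log-sum-exp, the two representations carry the same information; a minimal sketch of the relation (illustrative values only, not code from this commit):

```python
import torch

# Illustrative values; assumes mid_output_lse stored a per-split log-sum-exp.
max_logit = torch.tensor(3.2)              # per-split maximum attention logit
exp_sum = torch.tensor(17.5)               # per-split sum of exp(logit - max_logit)
lse = max_logit + torch.log(exp_sum)       # the single statistic the old buffer would hold
exp_sum_back = torch.exp(lse - max_logit)  # recovering exp_sum from (lse, max_logit)
```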


@@ -122,6 +122,8 @@ def benchmark_flash_decoding_attention(
     mid_output_lse = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
     )
+    exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
+    max_logits = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
     if provider == "vllm_paged_decoding_attention":
         alibi_slopes = None
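The new buffers are sized `(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num)`: one entry per (sequence, head, KV split). Conceptually, each split produces a partial attention output plus the two statistics the later merge needs. A minimal single-split sketch in plain PyTorch (not the CUDA kernel added in this commit; names and shapes are illustrative):

```python
import torch

def attend_one_split(q, k_split, v_split, sm_scale):
    """One KV split for a single (sequence, head); illustrative, not the kernel."""
    # q: [head_dim], k_split / v_split: [split_len, head_dim]
    logits = (k_split @ q) * sm_scale            # [split_len]
    max_logit = logits.max()                     # split-local maximum  -> one max_logits entry
    weights = torch.exp(logits - max_logit)      # numerically safe exponentials
    exp_sum = weights.sum()                      # split-local denominator -> one exp_sums entry
    partial_out = (weights / exp_sum) @ v_split  # softmax-normalized within this split -> mid_output
    return partial_out, exp_sum, max_logit
```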
@@ -166,7 +168,8 @@ def benchmark_flash_decoding_attention(
            BLOCK_SIZE,
            max_seq_len_across_batch,
            mid_output,
-            mid_output_lse,
+            exp_sums,
+            max_logits,
            alibi_slopes,
            sm_scale,
        )
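After every split has run, a reduction step combines the per-split partial outputs using `exp_sums` and `max_logits`. A minimal sketch of the standard FlashDecoding / paged-attention-v2 merge, assuming `mid_output` holds per-split softmax-normalized partial outputs of shape `[batch, heads, max_splits, head_dim]` (this mirrors the usual reduction math, not necessarily the exact kernel added here):

```python
import torch

def reduce_splits(mid_output, exp_sums, max_logits, num_splits):
    """Merge per-split partial attention outputs; illustrative, not the kernel."""
    m = max_logits[..., :num_splits]                  # [B, H, S] split-local maxima
    e = exp_sums[..., :num_splits]                    # [B, H, S] split-local denominators
    out_parts = mid_output[..., :num_splits, :]       # [B, H, S, head_dim] partial outputs
    global_max = m.max(dim=-1, keepdim=True).values   # [B, H, 1]
    rescaled = e * torch.exp(m - global_max)          # bring every split to a common scale
    weights = rescaled / rescaled.sum(dim=-1, keepdim=True)
    # Weighted combination of the per-split partial outputs gives the full softmax result.
    return (out_parts * weights.unsqueeze(-1)).sum(dim=-2)  # [B, H, head_dim]
```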