Mirror of https://github.com/hpcaitech/ColossalAI.git
add paged-attention v2: support seq length split across thread blocks (#5707)
@@ -122,6 +122,8 @@ def benchmark_flash_decoding_attention(
     mid_output_lse = torch.empty(
         size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
     )
+    exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
+    max_logits = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
 
     if provider == "vllm_paged_decoding_attention":
         alibi_slopes = None
@@ -166,7 +168,8 @@ def benchmark_flash_decoding_attention(
             BLOCK_SIZE,
             max_seq_len_across_batch,
             mid_output,
-            mid_output_lse,
+            exp_sums,
+            max_logits,
             alibi_slopes,
             sm_scale,
         )
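For context: in paged attention v2, the KV sequence of each request is split into partitions, and each partition is assigned to its own thread block. Every partition writes a partial attention output (mid_output) together with its softmax statistics, exp_sums and max_logits, and a final reduction merges the partitions into the exact full-sequence softmax result. That is why the benchmark now allocates and passes these two extra tensors, each shaped (BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num). Below is a minimal PyTorch sketch of that reduction step; the function name combine_split_kv_partials and the assumption of a single uniform partition count are illustrative only, not ColossalAI's actual kernel API.

import torch

def combine_split_kv_partials(mid_output, exp_sums, max_logits, num_partitions):
    # Illustrative sketch, not the ColossalAI kernel.
    # mid_output:  (batch, heads, parts, head_dim) -- per-partition attention
    #              outputs, each normalized by its own partition exp-sum
    # exp_sums:    (batch, heads, parts) -- sum of exp(logit - partition_max)
    # max_logits:  (batch, heads, parts) -- max attention logit per partition
    global_max = max_logits[..., :num_partitions].max(dim=-1, keepdim=True).values
    # Rescale every partition's exp-sum from its local max to the global max.
    rescaled_sums = exp_sums[..., :num_partitions] * torch.exp(
        max_logits[..., :num_partitions] - global_max
    )
    weights = rescaled_sums / rescaled_sums.sum(dim=-1, keepdim=True)
    # A weighted sum of the per-partition outputs reproduces the exact
    # full-sequence softmax-weighted value aggregation.
    return (mid_output[..., :num_partitions, :] * weights.unsqueeze(-1)).sum(dim=-2)

Tracking max_logits separately from exp_sums keeps the merge numerically stable: since the global max dominates every partition max, each rescaling factor exp(max_logits - global_max) is at most 1 and never overflows.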