add paged-attention v2: support seq length split across thread blocks (#5707)

Author: Steve Luo
Date: 2024-05-14 12:46:54 +08:00
Committed by: GitHub
Parent: 18d67d0e8e
Commit: 7806842f2d
8 changed files with 704 additions and 249 deletions


@@ -72,7 +72,8 @@ void flash_decoding_attention(
     int block_size, int max_context_len,
     torch::Tensor&
         tmp_out,  // [num_tokens, num_heads, max_num_partitions, head_size]
-    torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
+    torch::Tensor& exp_sums,    // [num_tokens, num_heads, max_num_partitions]
+    torch::Tensor& max_logits,  // [num_tokens, num_heads, max_num_partitions]
     const c10::optional<torch::Tensor>& alibi_slopes, float scale);
 void convert_fp8(torch::Tensor& input, torch::Tensor& output);
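
For context, the new exp_sums and max_logits buffers hold per-partition softmax statistics so that partial attention outputs computed by different thread blocks (one per sequence-length partition) can be combined afterwards. The following is a minimal, standalone C++ sketch of that reduction for a single (token, head) pair; the function name, container types, and shapes are illustrative assumptions and not the actual CUDA reduce kernel added in this commit.

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch only (assumed layout, not the real kernel): combine per-partition
// partial attention outputs for one (token, head). Assumes:
//   tmp_out[p]    -> partial output of partition p, length head_size
//   exp_sums[p]   -> sum of exp(logit - max_logits[p]) over keys in partition p
//   max_logits[p] -> maximum logit seen in partition p
std::vector<float> reduce_partitions(
    const std::vector<std::vector<float>>& tmp_out,
    const std::vector<float>& exp_sums,
    const std::vector<float>& max_logits,
    int head_size) {
  const int num_partitions = static_cast<int>(exp_sums.size());

  // Global max logit across partitions, for numerical stability.
  float global_max = -INFINITY;
  for (int p = 0; p < num_partitions; ++p) {
    global_max = std::max(global_max, max_logits[p]);
  }

  // Rescale each partition's exp-sum to the global max and accumulate.
  float global_sum = 0.0f;
  std::vector<float> rescale(num_partitions);
  for (int p = 0; p < num_partitions; ++p) {
    rescale[p] = exp_sums[p] * std::exp(max_logits[p] - global_max);
    global_sum += rescale[p];
  }

  // Final output is the weighted average of the per-partition outputs.
  std::vector<float> out(head_size, 0.0f);
  for (int p = 0; p < num_partitions; ++p) {
    const float w = rescale[p] / global_sum;
    for (int d = 0; d < head_size; ++d) {
      out[d] += tmp_out[p][d] * w;
    }
  }
  return out;
}

This is the standard log-sum-exp merge used by partitioned (v2-style) decoding attention: each partition only needs to expose its running max and exp-sum for its partial softmax to be rescaled into a globally consistent weighting.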