[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)

* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator

* refactor decode_kv_cache_memcpy

* enable ALiBi in PagedAttention
Author: Steve Luo
Date: 2024-04-30 15:52:23 +08:00
Committed by: GitHub
Parent: 5f00002e43
Commit: 5cd75ce4c7
14 changed files with 368 additions and 235 deletions
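The diff below threads an optional per-head ALiBi slope tensor from the public flash_decoding_attention entry point down to the flash-decoding CUDA kernel. ALiBi (Attention with Linear Biases) penalizes each query-key score in proportion to how far the key lies behind the current decoding position: for a head with slope m and a key at index j in a context of length L, the kernel adds m * (j - L + 1) to the raw score, which is 0 for the most recent key and increasingly negative for older ones. A null slope pointer degenerates to m = 0, so models without ALiBi take the same code path as before.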


@@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel(
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const int* __restrict__ context_lens, // [num_tokens]
const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq]
+const float* __restrict__ alibi_slopes, // [num_heads]
const int max_seq_len,
const int num_kv_heads,
const float scale,
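The kernel consumes one float slope per attention head; nothing in this commit shows how that tensor is produced. A minimal host-side sketch, assuming the standard geometric ALiBi schedule and a hypothetical helper name (make_alibi_slopes is not part of this diff):

    // Hypothetical helper: build one ALiBi slope per head as a float32 tensor
    // on the target device, matching the kernel's [num_heads] expectation.
    #include <torch/torch.h>
    #include <cmath>
    #include <vector>

    torch::Tensor make_alibi_slopes(int num_heads, const torch::Device& device) {
      std::vector<float> slopes(num_heads);
      for (int h = 0; h < num_heads; ++h) {
        // Standard schedule for a power-of-two head count: m_h = 2^(-8*(h+1)/num_heads).
        slopes[h] = std::pow(2.0f, -8.0f * (h + 1) / num_heads);
      }
      // Copy out of the stack-allocated vector and move to the requested device.
      return torch::from_blob(slopes.data(), {num_heads}, torch::kFloat32).clone().to(device);
    }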
@@ -105,6 +106,7 @@ __global__ void flash_decoding_attention_kernel(
using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
const int context_len = context_lens[seq_idx];
+const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int thread_group_offset = lane % NUM_THREADS_PER_X;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -164,6 +166,7 @@ __global__ void flash_decoding_attention_kernel(
if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
+qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
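The added line folds the bias into the raw score before masking. As a concrete check, with alibi_slope = 0.5 and context_len = 4, the keys at token_idx 0, 1, 2, 3 receive biases 0.5 * (0 - 4 + 1) = -1.5, then -1.0, -0.5 and 0.0; any token_idx at or beyond context_len is still zeroed in logits and excluded from qk_max by the existing mask, so the bias only ever shifts valid positions.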
@@ -261,6 +264,7 @@ __global__ void flash_decoding_attention_kernel(
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
context_lens.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
+alibi_slopes_ptr, \
max_context_len, \
num_kv_heads, \
scale, \
@@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher(
torch::Tensor& context_lens, // [num_tokens]
torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
int max_context_len,
-float scale) {
+float scale,
+const c10::optional<torch::Tensor>& alibi_slopes) {
int num_tokens = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -304,6 +309,10 @@ void flash_decoding_attention_v1_launcher(
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
+const float* alibi_slopes_ptr = alibi_slopes ?
+reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+: nullptr;
dim3 grid(num_heads, num_tokens, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
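This hunk lowers the optional tensor to a plain const float* (nullptr when no slopes are supplied), which is the only form the kernel argument list accepts. The surrounding shared-memory sizing is unchanged; as a worked example of the rounding, if max_num_blocks_per_seq = 33 the block table needs 33 * sizeof(int) = 132 bytes, DIVIDE_ROUND_UP(132, 16) = 9 float4 chunks, so 144 bytes are reserved on top of max(logits_size, outputs_size).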
@@ -336,7 +345,8 @@ void flash_decoding_attention_v1_launcher(
context_lens, \
block_tables, \
max_context_len, \
-scale);
+scale, \
+alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -367,6 +377,7 @@ void flash_decoding_attention(
int max_context_len,
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
+const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {
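Callers that do not use ALiBi pass c10::nullopt; alibi_slopes_ptr then stays nullptr, the per-head slope falls back to 0.f, and the scores come out exactly as before this change. A sketch of the call-site choice, assuming the hypothetical make_alibi_slopes helper above and a model_uses_alibi flag (the unchanged arguments are elided in the comment):

    // Sketch only: build the optional argument; the full flash_decoding_attention
    // call is omitted because its other parameters are unchanged by this commit.
    c10::optional<torch::Tensor> maybe_slopes =
        model_uses_alibi
            ? c10::optional<torch::Tensor>(make_alibi_slopes(num_heads, query.device()))
            : c10::nullopt;
    // flash_decoding_attention(..., maybe_slopes, scale);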