Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-11 22:10:37 +00:00)
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator
* refactor decode_kv_cache_memcpy
* enable alibi in pagedattention
@@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel(
   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, block_size, head_size]
   const int* __restrict__ context_lens,   // [num_tokens]
   const int* __restrict__ block_tables,   // [num_tokens, max_num_blocks_per_seq]
+  const float* __restrict__ alibi_slopes, // [num_heads]
   const int max_seq_len,
   const int num_kv_heads,
   const float scale,
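Note: the new alibi_slopes argument is a per-head slope array of shape [num_heads]. As a rough illustration of what such an array typically contains (the helper below is hypothetical and not part of this commit; it follows the common ALiBi convention of a geometric series for a power-of-two head count):

#include <cmath>
#include <vector>

// Hypothetical helper (not in this commit): per-head ALiBi slopes for a
// power-of-two number of heads, using the common 2^(-8 * (i + 1) / num_heads) rule.
std::vector<float> make_alibi_slopes(int num_heads) {
    std::vector<float> slopes(num_heads);
    for (int i = 0; i < num_heads; ++i) {
        slopes[i] = std::pow(2.0f, -8.0f * (i + 1) / num_heads);
    }
    return slopes;  // e.g. num_heads = 8 -> 0.5, 0.25, 0.125, ..., 0.00390625
}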
@@ -105,6 +106,7 @@ __global__ void flash_decoding_attention_kernel(
   using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;

   const int context_len = context_lens[seq_idx];
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
   const int thread_group_offset = lane % NUM_THREADS_PER_X;
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -164,6 +166,7 @@ __global__ void flash_decoding_attention_kernel(

       if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
         const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
+        qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
         const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
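Note: a standalone host-side sketch of what the added bias term does (names invented for the example): the newest token (token_idx == context_len - 1) receives a bias of 0, and earlier tokens receive an increasingly negative bias scaled by the head's slope.

#include <vector>

// Illustrative mirror of the kernel's
//   qk += alibi_slope * (token_idx - context_len + 1)
// for a single head, applied to one row of attention scores on the CPU.
void add_alibi_bias(std::vector<float>& scores, float alibi_slope, int context_len) {
    for (int token_idx = 0; token_idx < context_len; ++token_idx) {
        // distance is 0 for the newest token, -1 for the one before it, and so on
        const int distance = token_idx - context_len + 1;
        scores[token_idx] += alibi_slope * static_cast<float>(distance);
    }
}
// Example: context_len = 4, alibi_slope = 0.5 -> biases -1.5, -1.0, -0.5, 0.0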
@@ -261,6 +264,7 @@ __global__ void flash_decoding_attention_kernel(
     reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
     context_lens.data_ptr<int>(), \
     block_tables.data_ptr<int>(), \
+    alibi_slopes_ptr, \
     max_context_len, \
     num_kv_heads, \
     scale, \
@@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher(
   torch::Tensor& context_lens, // [num_tokens]
   torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
   int max_context_len,
-  float scale) {
+  float scale,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_tokens = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -304,6 +309,10 @@ void flash_decoding_attention_v1_launcher(
   // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);

+  const float* alibi_slopes_ptr = alibi_slopes ?
+      reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+      : nullptr;
+
   dim3 grid(num_heads, num_tokens, 1);
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
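Note: the added lines above turn the optional tensor into a raw device pointer. A minimal isolated sketch of that c10::optional<torch::Tensor> pattern (standalone example, not the repository's launcher):

#include <torch/torch.h>

// Standalone illustration of the nullable-slopes pattern: a float32 tensor of
// shape [num_heads] enables ALiBi, while c10::nullopt disables it and the
// kernel ends up receiving a null pointer (and skips the bias).
const float* get_alibi_slopes_ptr(const c10::optional<torch::Tensor>& alibi_slopes) {
    return alibi_slopes.has_value()
               ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
               : nullptr;
}

In the diff above, the launch macro forwards the same pointer (alibi_slopes_ptr) to the kernel.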
@@ -336,7 +345,8 @@ void flash_decoding_attention_v1_launcher(
     context_lens, \
     block_tables, \
     max_context_len, \
-    scale);
+    scale, \
+    alibi_slopes);

 // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
@@ -367,6 +377,7 @@ void flash_decoding_attention(
   int max_context_len,
   torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
   torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
+  const c10::optional<torch::Tensor>& alibi_slopes,
   float scale) {