[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)

* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator

* refactor decode_kv_cache_memcpy

* enable alibi in pagedattention
This commit is contained in:
Steve Luo
2024-04-30 15:52:23 +08:00
committed by GitHub
parent 5f00002e43
commit 5cd75ce4c7
14 changed files with 368 additions and 235 deletions

View File

@@ -24,14 +24,15 @@ __global__ void context_kv_cache_memcpy_kernel(
const int batch_size,
const int block_table_stride,
const int64_t key_stride,
const int64_t value_stride
const int64_t value_stride,
const int x
)
{
const int seq_token_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
return ;
}
@@ -40,23 +41,33 @@ __global__ void context_kv_cache_memcpy_kernel(
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
int head_id;
int head_offset;
int x_id;
int x_offset;
int64_t key_src_id;
int64_t value_src_id;
int64_t target_id;
int64_t target_key_id;
int64_t target_value_id;
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
head_id = i / head_dim;
head_offset = i % head_dim;
x_id = head_offset / x;
x_offset = head_offset % x;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
}
// tail process
@@ -64,14 +75,21 @@ __global__ void context_kv_cache_memcpy_kernel(
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
x_id = head_offset / x;
x_offset = head_offset % x;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
}
}
@@ -81,7 +99,7 @@ template<typename T, typename CacheT>
void apply_context_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, head_num, head_dim]
torch::Tensor& value, // [num_tokens, head_num, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& cu_seqlens, // [batch_size + 1]
@@ -91,7 +109,8 @@ void apply_context_kv_cache_memcpy(
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int batch_size = block_tables.size(0);
int64_t key_stride = key.stride(0);
@@ -127,7 +146,8 @@ void apply_context_kv_cache_memcpy(
batch_size, \
block_table_stride, \
key_stride, \
value_stride \
value_stride, \
x \
); \
} while(0)
@@ -164,7 +184,7 @@ void apply_context_kv_cache_memcpy(
void context_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, head_num, head_dim]
torch::Tensor& value, // [num_tokens, head_num, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& cu_seqlens, // [batch_size + 1]