[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)

* refactor the kvcache manager and the rotary_embedding and kvcache_memcpy operators

* refactor decode_kv_cache_memcpy

* enable ALiBi in PagedAttention
Author: Steve Luo
Date: 2024-04-30 15:52:23 +08:00
Committed by: GitHub
Parent: 5f00002e43
Commit: 5cd75ce4c7

14 changed files with 368 additions and 235 deletions

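The central change in this file is the key-cache layout: it moves from [num_blocks, head_num, block_size, head_dim] to [num_blocks, head_num, head_dim/x, block_size, x], splitting each head's dimension into chunks of x contiguous elements (the vLLM-style paged-attention layout that lets keys be read with vectorized loads); the value cache keeps the old layout. The sketch below is not part of the commit: names mirror the kernel, and hidden_size = head_num * head_dim is an assumption consistent with the index math. It shows the two flat-offset computations the kernel now performs.

```cpp
#include <cstdint>
#include <cstdio>

// Key cache: [num_blocks, head_num, head_dim/x, block_size, x].
// Mirrors target_key_id in the kernel diff below.
int64_t key_cache_offset(int block_id, int block_offset, int head_id,
                         int head_offset, int hidden_size, int block_size,
                         int head_dim, int x) {
    const int x_id     = head_offset / x;  // which x-chunk of the head dim
    const int x_offset = head_offset % x;  // position inside that chunk
    return (int64_t)block_id * hidden_size * block_size  // block
         + (int64_t)head_id * block_size * head_dim      // head
         + (int64_t)x_id * block_size * x                // chunk
         + (int64_t)block_offset * x                     // token slot
         + x_offset;                                     // element
}

// Value cache: [num_blocks, head_num, block_size, head_dim] (unchanged).
// Mirrors target_value_id in the kernel diff below.
int64_t value_cache_offset(int block_id, int block_offset, int head_id,
                           int head_offset, int hidden_size, int block_size,
                           int head_dim) {
    return (int64_t)block_id * hidden_size * block_size
         + (int64_t)head_id * block_size * head_dim
         + (int64_t)block_offset * head_dim + head_offset;
}

int main() {
    // Example: 2 heads, head_dim=128, x=8, block_size=16 -> hidden_size=256.
    printf("key offset:   %lld\n",
           (long long)key_cache_offset(1, 3, 0, 10, 256, 16, 128, 8));
    printf("value offset: %lld\n",
           (long long)value_cache_offset(1, 3, 0, 10, 256, 16, 128));
    return 0;
}
```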

@@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
     const int block_size,
     const int64_t key_stride,
     const int64_t value_stride,
-    const int block_table_stride
+    const int block_table_stride,
+    const int x
 )
 {
     const int seq_id = blockIdx.x;
@@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel(
     for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
         const int head_id = i / head_dim;
         const int head_offset = i % head_dim;
+        const int x_id = head_offset / x;
+        const int x_offset = head_offset % x;
         const int64_t key_src_id = seq_id * key_stride + i;
         const int64_t value_src_id = seq_id * value_stride + i;
-        const int64_t target_id = block_id * hidden_size * block_size
+        const int64_t target_key_id = block_id * hidden_size * block_size
                                       + head_id * block_size * head_dim
+                                      + x_id * block_size * x
+                                      + block_offset * x
+                                      + x_offset;
+        const int64_t target_value_id = block_id * hidden_size * block_size
+                                      + head_id * block_size * head_dim
                                       + block_offset * head_dim + head_offset;
-        copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
-        copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
+        copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
+        copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
     }

     if (!Aligned) {
         for (; i < hidden_size; ++i ) {
             const int head_id = i / head_dim;
             const int head_offset = i % head_dim;
+            const int x_id = head_offset / x;
+            const int x_offset = head_offset % x;
             const int64_t key_src_id = seq_id * key_stride + i;
             const int64_t value_src_id = seq_id * value_stride + i;
-            const int64_t target_id = block_id * hidden_size * block_size
+            const int64_t target_key_id = block_id * hidden_size * block_size
                                           + head_id * block_size * head_dim
+                                          + x_id * block_size * x
+                                          + block_offset * x
+                                          + x_offset;
+            const int64_t target_value_id = block_id * hidden_size * block_size
+                                          + head_id * block_size * head_dim
                                           + block_offset * head_dim + head_offset;
-            key_cache[target_id] = key[key_src_id];
-            value_cache[target_id] = value[value_src_id];
+            key_cache[target_key_id] = key[key_src_id];
+            value_cache[target_value_id] = value[value_src_id];
         }
     }
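One subtlety worth calling out (my reading of the diff, not stated in it): the aligned path still issues copy_vector writes of VecSize contiguous elements into the 5-D key cache. That stays correct only because consecutive head_offset values inside one x-chunk map to consecutive cache addresses, with a jump of block_size * x between chunks, so the scheme presumably relies on VecSize dividing x. A standalone check:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes: block_size=16, head_dim=128, x=8, hidden_size=256.
    const int block_id = 0, block_offset = 5, head_id = 2;
    const int hidden_size = 256, block_size = 16, head_dim = 128, x = 8;
    for (int head_offset = 0; head_offset < 2 * x; ++head_offset) {
        const int x_id = head_offset / x, x_offset = head_offset % x;
        const int64_t off = (int64_t)block_id * hidden_size * block_size
                          + (int64_t)head_id * block_size * head_dim
                          + (int64_t)x_id * block_size * x
                          + (int64_t)block_offset * x + x_offset;
        // Offsets advance by 1 inside a chunk and jump at head_offset == x.
        printf("head_offset=%2d -> cache offset %lld\n",
               head_offset, (long long)off);
    }
    return 0;
}
```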
@@ -69,7 +84,7 @@ template<typename scalar_t>
 void apply_decode_kv_cache_memcpy(
     at::Tensor& key,                 // [num_tokens, head_num, head_dim]
     at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,    // [batch_size]
     at::Tensor& block_tables)        // [batch_size, max_seq_len]
@@ -77,7 +92,8 @@ void apply_decode_kv_cache_memcpy(
     int num_tokens = key.size(0);
     int head_num = key.size(1);
     int head_dim = key.size(2);
-    int block_size = key_cache.size(2);
+    int block_size = key_cache.size(3);
+    int x = key_cache.size(4);
     int64_t key_stride = key.stride(0);
     int64_t value_stride = value.stride(0);
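Since block_size and x are now read from dimensions 3 and 4, the host wrapper implicitly requires a 5-D key cache. A hedged sanity-check sketch (not in the commit; it would slot in right after the size reads above, and TORCH_CHECK is ATen's standard assertion macro) that would make the new contract explicit:

```cpp
// key_cache: [num_blocks, head_num, head_dim/x, block_size, x]
TORCH_CHECK(key_cache.dim() == 5, "key_cache must be 5-D after the layout change");
TORCH_CHECK(key_cache.size(2) * key_cache.size(4) == head_dim,
            "expected head_dim == (head_dim/x) * x");
// value_cache keeps the [num_blocks, head_num, block_size, head_dim] layout
TORCH_CHECK(value_cache.size(3) == head_dim,
            "value_cache last dim must be head_dim");
```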
@@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy(
         block_size, \
         key_stride, \
         value_stride, \
-        block_table_stride \
+        block_table_stride, \
+        x \
     ); \
 } while(0)
@@ -146,7 +163,7 @@ void apply_decode_kv_cache_memcpy(
 void decode_kv_cache_memcpy(
     at::Tensor& key,                 // [num_tokens, head_num, head_dim]
     at::Tensor& value,               // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key_cache,           // [num_blocks, head_num, head_dim/x, block_size, x]
     at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,    // [batch_size]
     at::Tensor& block_tables)        // [batch_size, max_seq_len]
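To round out the picture, a hedged caller-side sketch (not from the repo: the declaration is copied from this diff, but the build setup, dtypes, and one-token-per-sequence decode assumption are mine) allocating tensors with the documented shapes and invoking the op:

```cpp
#include <torch/torch.h>

// Declaration as it appears in this diff; the definition lives in the extension.
void decode_kv_cache_memcpy(at::Tensor& key, at::Tensor& value,
                            at::Tensor& key_cache, at::Tensor& value_cache,
                            at::Tensor& sequence_lengths, at::Tensor& block_tables);

int main() {
    const int batch = 2, head_num = 8, head_dim = 128, x = 8;
    const int num_blocks = 32, block_size = 16, max_seq_len = 4;
    auto fp = torch::dtype(torch::kFloat16).device(torch::kCUDA);
    auto ip = torch::dtype(torch::kInt32).device(torch::kCUDA);

    // Decode step: one new token per sequence, so num_tokens == batch.
    auto key   = torch::randn({batch, head_num, head_dim}, fp);
    auto value = torch::randn({batch, head_num, head_dim}, fp);
    auto key_cache   = torch::zeros({num_blocks, head_num, head_dim / x, block_size, x}, fp);
    auto value_cache = torch::zeros({num_blocks, head_num, block_size, head_dim}, fp);
    auto sequence_lengths = torch::ones({batch}, ip);            // current length per sequence
    auto block_tables = torch::zeros({batch, max_seq_len}, ip);  // logical -> physical blocks

    decode_kv_cache_memcpy(key, value, key_cache, value_cache,
                           sequence_lengths, block_tables);
    return 0;
}
```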