mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator * refactor decode_kv_cache_memcpy * enable alibi in pagedattention
This commit is contained in:
@@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
||||
const int block_size,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride,
|
||||
const int block_table_stride
|
||||
const int block_table_stride,
|
||||
const int x
|
||||
)
|
||||
{
|
||||
const int seq_id = blockIdx.x;
|
||||
@@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
||||
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int x_id = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
const int64_t value_src_id = seq_id * value_stride + i;
|
||||
const int64_t target_id = block_id * hidden_size * block_size
|
||||
const int64_t target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
|
||||
}
|
||||
|
||||
if (!Aligned) {
|
||||
for (; i < hidden_size; ++i ) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int x_id = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
const int64_t value_src_id = seq_id * value_stride + i;
|
||||
const int64_t target_id = block_id * hidden_size * block_size
|
||||
const int64_t target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = key[key_src_id];
|
||||
value_cache[target_id] = value[value_src_id];
|
||||
key_cache[target_key_id] = key[key_src_id];
|
||||
value_cache[target_value_id] = value[value_src_id];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +84,7 @@ template<typename scalar_t>
|
||||
void apply_decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
@@ -77,7 +92,8 @@ void apply_decode_kv_cache_memcpy(
|
||||
int num_tokens = key.size(0);
|
||||
int head_num = key.size(1);
|
||||
int head_dim = key.size(2);
|
||||
int block_size = key_cache.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
@@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy(
|
||||
block_size, \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
block_table_stride \
|
||||
block_table_stride, \
|
||||
x \
|
||||
); \
|
||||
} while(0)
|
||||
|
||||
@@ -146,7 +163,7 @@ void apply_decode_kv_cache_memcpy(
|
||||
void decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
|
Reference in New Issue
Block a user