mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator
* refactor decode_kv_cache_memcpy
* enable alibi in pagedattention
This commit is contained in:
@@ -24,14 +24,15 @@ __global__ void context_kv_cache_memcpy_kernel(
|
||||
const int batch_size,
|
||||
const int block_table_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride
|
||||
const int64_t value_stride,
|
||||
const int x
|
||||
)
|
||||
{
|
||||
const int seq_token_id = blockIdx.x;
|
||||
const int seq_id = blockIdx.y;
|
||||
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
|
||||
|
||||
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
|
||||
if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
|
||||
return ;
|
||||
}
|
||||
|
||||
@@ -40,23 +41,33 @@ __global__ void context_kv_cache_memcpy_kernel(
|
||||
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
|
||||
int head_id;
|
||||
int head_offset;
|
||||
int x_id;
|
||||
int x_offset;
|
||||
int64_t key_src_id;
|
||||
int64_t value_src_id;
|
||||
int64_t target_id;
|
||||
int64_t target_key_id;
|
||||
int64_t target_value_id;
|
||||
|
||||
int i = threadIdx.x * VecSize;
|
||||
|
||||
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
|
||||
head_id = i / head_dim;
|
||||
head_offset = i % head_dim;
|
||||
x_id = head_offset / x;
|
||||
x_offset = head_offset % x;
|
||||
key_src_id = total_token_id * key_stride + i;
|
||||
value_src_id = total_token_id * value_stride + i;
|
||||
target_id = block_id * hidden_size * block_size
|
||||
target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
|
||||
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
|
||||
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
|
||||
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
|
||||
}
|
||||
|
||||
// tail process
|
||||
@@ -64,14 +75,21 @@ __global__ void context_kv_cache_memcpy_kernel(
|
||||
for (; i < hidden_size; ++i ) {
|
||||
head_id = i / head_dim;
|
||||
head_offset = i % head_dim;
|
||||
x_id = head_offset / x;
|
||||
x_offset = head_offset % x;
|
||||
key_src_id = total_token_id * key_stride + i;
|
||||
value_src_id = total_token_id * value_stride + i;
|
||||
target_id = block_id * hidden_size * block_size
|
||||
target_key_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ x_id * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
target_value_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
|
||||
value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
|
||||
key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
|
||||
value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +99,7 @@ template<typename T, typename CacheT>
|
||||
void apply_context_kv_cache_memcpy(
|
||||
torch::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
@@ -91,7 +109,8 @@ void apply_context_kv_cache_memcpy(
|
||||
int num_tokens = key.size(0);
|
||||
int head_num = key.size(1);
|
||||
int head_dim = key.size(2);
|
||||
int block_size = key_cache.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
int batch_size = block_tables.size(0);
|
||||
|
||||
int64_t key_stride = key.stride(0);
|
||||
@@ -127,7 +146,8 @@ void apply_context_kv_cache_memcpy(
|
||||
batch_size, \
|
||||
block_table_stride, \
|
||||
key_stride, \
|
||||
value_stride \
|
||||
value_stride, \
|
||||
x \
|
||||
); \
|
||||
} while(0)
|
||||
|
||||
@@ -164,7 +184,7 @@ void apply_context_kv_cache_memcpy(
|
||||
void context_kv_cache_memcpy(
|
||||
torch::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& cu_seqlens, // [batch_size + 1]
|
||||
|
Reference in New Issue
Block a user