Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy operator (#5663)
* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator
* refactor decode_kv_cache_memcpy
* enable alibi in pagedattention
@@ -1,18 +1,19 @@
 #include <torch/extension.h>

 void decode_kv_cache_memcpy(
-    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
-    torch::Tensor& key_cache,  // [num_blocks, num_heads, block_size, head_size]
+    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, head_num, head_dim/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, block_size, head_size]
+    torch::Tensor& sequence_lengths,  // [batch_size]
+    torch::Tensor& block_tables);     // [batch_size, max_seq_len]

 void context_kv_cache_memcpy(
-    at::Tensor& key,        // [num_tokens, head_num, head_dim]
-    at::Tensor& value,      // [num_tokens, head_num, head_dim]
-    at::Tensor& key_cache,  // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& key,          // [num_tokens, head_num, head_dim]
+    at::Tensor& value,        // [num_tokens, head_num, head_dim]
+    at::Tensor& key_cache,    // [num_blocks, head_num, head_dim/x, block_size, x]
+    at::Tensor& value_cache,  // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& sequence_lengths,  // [batch_size]
+    at::Tensor& cu_seqlens,        // [batch_size + 1]
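
The key cache moves from the flat [num_blocks, num_heads, block_size, head_size] layout to the vectorized [num_blocks, head_num, head_dim/x, block_size, x] layout. As a minimal sketch of the flat-offset arithmetic that shape implies (assuming x is the number of elements packed per vectorized access; the helper name and the example in main are illustrative, not part of the PR):

#include <cstdint>
#include <cstdio>

// Flat element offset for a key cache laid out as
// [num_blocks, head_num, head_dim/x, block_size, x]; x is the number of
// elements packed per vectorized access.
int64_t key_cache_offset(int64_t block_idx, int64_t head_idx, int64_t dim_idx,
                         int64_t slot_idx, int64_t head_num, int64_t head_dim,
                         int64_t block_size, int64_t x) {
  const int64_t outer = dim_idx / x;  // which x-wide chunk of head_dim
  const int64_t inner = dim_idx % x;  // position inside that chunk
  return (((block_idx * head_num + head_idx) * (head_dim / x) + outer) *
              block_size +
          slot_idx) *
             x +
         inner;
}

int main() {
  // Example: 8 heads, head_dim = 128, block_size = 16, x = 8.
  std::printf("%lld\n", static_cast<long long>(
                            key_cache_offset(3, 2, 37, 5, 8, 128, 16, 8)));
  return 0;
}

Grouping the innermost x elements of each head dimension contiguously is what lets the copy and attention kernels issue wide, aligned loads per block slot; the value cache keeps its original [num_blocks, head_num, block_size, head_dim] shape.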
@@ -27,12 +28,13 @@ void rotary_embedding(
     bool high_precision);

 void rotary_embedding_and_cache_copy(
-    torch::Tensor& query,      // [num_tokens, head_num, head_dim]
-    torch::Tensor& key,        // [num_tokens, kv_head_num, head_dim]
-    torch::Tensor& value,      // [num_tokens, num_heads, head_dim]
-    torch::Tensor& cos,        // [num_tokens, head_dim]
-    torch::Tensor& sin,        // [num_tokens, head_dim]
-    torch::Tensor& key_cache,  // [num_blocks, num_heads, block_size, head_dim]
+    torch::Tensor& query,  // [num_tokens, head_num, head_dim]
+    torch::Tensor& key,    // [num_tokens, kv_head_num, head_dim]
+    torch::Tensor& value,  // [num_tokens, num_heads, head_dim]
+    torch::Tensor& cos,    // [num_tokens, head_dim]
+    torch::Tensor& sin,    // [num_tokens, head_dim]
+    torch::Tensor&
+        key_cache,  // [num_blocks, head_num, head_dim/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, block_size, head_dim]
+    torch::Tensor& sequence_lengths,  // [batch_size]
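
rotary_embedding_and_cache_copy fuses the RoPE rotation of query/key with the copy into the paged KV cache. As a scalar reference for the rotation itself, here is a sketch assuming the "rotate-half" pairing (element i pairs with i + head_dim/2) and one cos/sin value per rotary dimension; it is not the kernel's actual code:

#include <cstddef>
#include <cstdio>

// Rotate one head of query or key in place. cos_vals/sin_vals hold one value
// per rotary dimension (head_dim/2 entries in this sketch).
void rope_rotate_head(float* head, const float* cos_vals,
                      const float* sin_vals, std::size_t head_dim) {
  const std::size_t half = head_dim / 2;
  for (std::size_t i = 0; i < half; ++i) {
    const float x1 = head[i];
    const float x2 = head[i + half];
    head[i]        = x1 * cos_vals[i] - x2 * sin_vals[i];
    head[i + half] = x2 * cos_vals[i] + x1 * sin_vals[i];
  }
}

int main() {
  float head[4]     = {1.f, 0.f, 0.f, 1.f};
  float cos_vals[2] = {0.f, 1.f};  // e.g. angles pi/2 and 0
  float sin_vals[2] = {1.f, 0.f};
  rope_rotate_head(head, cos_vals, sin_vals, 4);
  std::printf("%g %g %g %g\n", head[0], head[1], head[2], head[3]);
  return 0;
}

Fusing this rotation with the cache copy avoids materializing the rotated key in global memory before writing it into the block-table slot.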
@@ -71,7 +73,7 @@ void flash_decoding_attention(
     torch::Tensor&
         tmp_out,  // [num_tokens, num_heads, max_num_partitions, head_size]
     torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
-    float scale);
+    const c10::optional<torch::Tensor>& alibi_slopes, float scale);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
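
flash_decoding_attention now takes an optional alibi_slopes tensor, matching the "enable alibi in pagedattention" item in the commit message. A sketch of the conventional per-head slope schedule a caller might pass, for a power-of-two head count as in the ALiBi paper (the helper name is hypothetical, not part of the extension):

#include <cmath>
#include <cstdio>
#include <vector>

// Slope for 0-based head i out of num_heads (power-of-two case):
// 2^(-8 * (i + 1) / num_heads), as defined in the ALiBi paper.
std::vector<float> alibi_slopes_pow2(int num_heads) {
  std::vector<float> slopes(num_heads);
  for (int i = 0; i < num_heads; ++i) {
    slopes[i] = std::pow(2.0f, -8.0f * (i + 1) / num_heads);
  }
  return slopes;
}

int main() {
  for (float s : alibi_slopes_pow2(8)) std::printf("%g ", s);  // 0.5 0.25 ... 0.00390625
  std::printf("\n");
  return 0;
}

Passing c10::nullopt keeps the previous behavior (no positional bias), so existing callers of the op remain valid.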