[Inference/Kernel] refactor kvcache manager and rotary_embedding and kvcache_memcpy oper… (#5663)

* refactor kvcache manager and rotary_embedding and kvcache_memcpy operator

* refactor decode_kv_cache_memcpy

* enable alibi in pagedattention
This commit is contained in:
Steve Luo
2024-04-30 15:52:23 +08:00
committed by GitHub
parent 5f00002e43
commit 5cd75ce4c7
14 changed files with 368 additions and 235 deletions

View File

@@ -24,14 +24,15 @@ __global__ void context_kv_cache_memcpy_kernel(
const int batch_size,
const int block_table_stride,
const int64_t key_stride,
const int64_t value_stride
const int64_t value_stride,
const int x
)
{
const int seq_token_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];
if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
return ;
}
@@ -40,23 +41,33 @@ __global__ void context_kv_cache_memcpy_kernel(
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
int head_id;
int head_offset;
int x_id;
int x_offset;
int64_t key_src_id;
int64_t value_src_id;
int64_t target_id;
int64_t target_key_id;
int64_t target_value_id;
int i = threadIdx.x * VecSize;
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
head_id = i / head_dim;
head_offset = i % head_dim;
x_id = head_offset / x;
x_offset = head_offset % x;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_id);
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_id);
copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
}
// tail process
@@ -64,14 +75,21 @@ __global__ void context_kv_cache_memcpy_kernel(
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
x_id = head_offset / x;
x_offset = head_offset % x;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
value_cache[target_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
}
}
@@ -81,7 +99,7 @@ template<typename T, typename CacheT>
void apply_context_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, head_num, head_dim]
torch::Tensor& value, // [num_tokens, head_num, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& cu_seqlens, // [batch_size + 1]
@@ -91,7 +109,8 @@ void apply_context_kv_cache_memcpy(
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int batch_size = block_tables.size(0);
int64_t key_stride = key.stride(0);
@@ -127,7 +146,8 @@ void apply_context_kv_cache_memcpy(
batch_size, \
block_table_stride, \
key_stride, \
value_stride \
value_stride, \
x \
); \
} while(0)
@@ -164,7 +184,7 @@ void apply_context_kv_cache_memcpy(
void context_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, head_num, head_dim]
torch::Tensor& value, // [num_tokens, head_num, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& cu_seqlens, // [batch_size + 1]

View File

@@ -20,7 +20,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
const int block_size,
const int64_t key_stride,
const int64_t value_stride,
const int block_table_stride
const int block_table_stride,
const int x
)
{
const int seq_id = blockIdx.x;
@@ -38,28 +39,42 @@ __global__ void decode_kv_cache_memcpy_kernel(
for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int x_id = head_offset / x;
const int x_offset = head_offset % x;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
const int64_t target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
const int64_t target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
}
if (!Aligned) {
for (; i < hidden_size; ++i ) {
const int head_id = i / head_dim;
const int head_offset = i % head_dim;
const int x_id = head_offset / x;
const int x_offset = head_offset % x;
const int64_t key_src_id = seq_id * key_stride + i;
const int64_t value_src_id = seq_id * value_stride + i;
const int64_t target_id = block_id * hidden_size * block_size
const int64_t target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
const int64_t target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;
key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
key_cache[target_key_id] = key[key_src_id];
value_cache[target_value_id] = value[value_src_id];
}
}
@@ -69,7 +84,7 @@ template<typename scalar_t>
void apply_decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
@@ -77,7 +92,8 @@ void apply_decode_kv_cache_memcpy(
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int64_t key_stride = key.stride(0);
int64_t value_stride = value.stride(0);
@@ -110,7 +126,8 @@ void apply_decode_kv_cache_memcpy(
block_size, \
key_stride, \
value_stride, \
block_table_stride \
block_table_stride, \
x \
); \
} while(0)
@@ -146,7 +163,7 @@ void apply_decode_kv_cache_memcpy(
void decode_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]

View File

@@ -67,6 +67,7 @@ __global__ void flash_decoding_attention_kernel(
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const int* __restrict__ context_lens, // [num_tokens]
const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq]
const float* __restrict__ alibi_slopes, // [num_heads]
const int max_seq_len,
const int num_kv_heads,
const float scale,
@@ -105,6 +106,7 @@ __global__ void flash_decoding_attention_kernel(
using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
const int context_len = context_lens[seq_idx];
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int thread_group_offset = lane % NUM_THREADS_PER_X;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -164,6 +166,7 @@ __global__ void flash_decoding_attention_kernel(
if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -261,6 +264,7 @@ __global__ void flash_decoding_attention_kernel(
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
context_lens.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
alibi_slopes_ptr, \
max_context_len, \
num_kv_heads, \
scale, \
@@ -282,7 +286,8 @@ void flash_decoding_attention_v1_launcher(
torch::Tensor& context_lens, // [num_tokens]
torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
int max_context_len,
float scale) {
float scale,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_tokens = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -304,6 +309,10 @@ void flash_decoding_attention_v1_launcher(
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
dim3 grid(num_heads, num_tokens, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -336,7 +345,8 @@ void flash_decoding_attention_v1_launcher(
context_lens, \
block_tables, \
max_context_len, \
scale);
scale, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -367,6 +377,7 @@ void flash_decoding_attention(
int max_context_len,
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {

View File

@@ -91,7 +91,7 @@ __device__ void apply_k_rotary_emb_compute(
const int* __restrict__ block_tables, const int64_t key_stride,
const int64_t value_stride, const int token_id,
const int block_table_stride, const int head_num, const int head_dim,
const int kv_head_num, const int block_size, const int half_head_dim,
const int kv_head_num, const int block_size, const int x, const int half_head_dim,
const int shard_block_size) {
const int seq_len = sequence_lengths[token_id] - 1;
const int block_offset = seq_len % block_size;
@@ -102,36 +102,40 @@ __device__ void apply_k_rotary_emb_compute(
return;
}
scalar_t x[VecSize];
scalar_t y[VecSize];
scalar_t x0[VecSize];
scalar_t x1[VecSize];
scalar_t out_x[VecSize];
scalar_t out_y[VecSize];
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
i += blockDim.x * VecSize) {
const int head_offset = i % half_head_dim;
const int half_head_offset = i % half_head_dim;
const int x_id = half_head_offset / x;
const int x_offset = half_head_offset % x;
const int shard_offset =
(head_offset / shard_block_size) * shard_block_size +
(head_offset % shard_block_size) / VecSize;
(half_head_offset / shard_block_size) * shard_block_size +
(half_head_offset % shard_block_size) / VecSize;
const int64_t addr_offset =
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
const int64_t target_id = block_id * kv_head_num * head_dim * block_size +
(i / half_head_dim) * block_size * head_dim +
block_offset * head_dim + head_offset;
token_id * key_stride + (i / half_head_dim) * head_dim + half_head_offset;
const int64_t target_id = block_id * kv_head_num * head_dim * block_size
+ (i / half_head_dim) * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
copy_vector<scalar_t, VecSize>(x0, key + addr_offset);
copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] -
static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]);
out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] +
static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]);
}
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size,
out_y);
}
@@ -162,7 +166,8 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
const int head_num,
const int head_dim,
const int kv_head_num,
const int block_size
const int block_size,
const int x
) {
const int token_id = blockIdx.x;
@@ -182,7 +187,7 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
//compute key and copy kv
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
}
template<typename scalar_t, typename m_scalar_t, int VecSize>
@@ -220,6 +225,31 @@ __global__ void rotary_embedding_kernel(
apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
}
#define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE) \
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size, stream>>>( \
query.data_ptr<scalar_t>(), \
key.data_ptr<scalar_t>(), \
value.data_ptr<scalar_t>(), \
cos.data_ptr<scalar_t>(), \
sin.data_ptr<scalar_t>(), \
key_cache.data_ptr<scalar_t>(), \
value_cache.data_ptr<scalar_t>(), \
sequence_lengths.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
query_stride, \
key_stride, \
value_stride, \
shard_element_num / 2, \
cos_stride, \
sin_stride, \
block_table_stride, \
head_num, \
head_dim, \
kv_head_num, \
block_size, \
x); \
template<typename scalar_t, bool high_precision>
void apply_rotary_embedding_and_cache_copy(
at::Tensor& query, // [num_tokens, head_num, head_dim]
@@ -227,7 +257,7 @@ void apply_rotary_embedding_and_cache_copy(
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables) // [batch_size, max_seq_len]
@@ -236,7 +266,8 @@ void apply_rotary_embedding_and_cache_copy(
int head_num = query.size(1);
int head_dim = query.size(2);
int kv_head_num = key.size(1);
int block_size = key_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int64_t query_stride = query.stride(0);
int64_t key_stride = key.stride(0);
@@ -261,80 +292,18 @@ void apply_rotary_embedding_and_cache_copy(
dim3 grid(num_tokens);
dim3 block(std::min(thread_nums, 512));
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
const int shared_memory_size = shard_element_num * sizeof(m_scalar_t);
switch (vec_size) {
case 1:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(1);
break;
case 2:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(2);
break;
case 4:
rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
cos.data_ptr<scalar_t>(),
sin.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
sequence_lengths.data_ptr<int>(),
block_tables.data_ptr<int>(),
query_stride,
key_stride,
value_stride,
shard_element_num / 2,
cos_stride,
sin_stride,
block_table_stride,
head_num,
head_dim,
kv_head_num,
block_size
);
ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(4);
break;
default:
AT_ERROR("Unsupported vectorized size ", vec_size);
@@ -441,7 +410,7 @@ void rotary_embedding_and_cache_copy(
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
at::Tensor& cos, // [num_tokens, head_dim]
at::Tensor& sin, // [num_tokens, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& block_tables, // [batch_size, max_seq_len]

View File

@@ -1,18 +1,19 @@
#include <torch/extension.h>
void decode_kv_cache_memcpy(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
@@ -27,12 +28,13 @@ void rotary_embedding(
bool high_precision);
void rotary_embedding_and_cache_copy(
torch::Tensor& query, // [num_tokens, head_num, head_dim]
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
torch::Tensor& cos, // [num_tokens, head_dim]
torch::Tensor& sin, // [num_tokens, head_dim]
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& query, // [num_tokens, head_num, head_dim]
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
torch::Tensor& cos, // [num_tokens, head_dim]
torch::Tensor& sin, // [num_tokens, head_dim]
torch::Tensor&
key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, block_size, head_dim]
torch::Tensor& sequence_lengths, // [batch_size]
@@ -71,7 +73,7 @@ void flash_decoding_attention(
torch::Tensor&
tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
float scale);
const c10::optional<torch::Tensor>& alibi_slopes, float scale);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,