mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Inference/kernel]Add Fused Rotary Embedding and KVCache Memcopy CUDA Kernel (#5418)
* add rotary embedding kernel * add rotary_embedding_kernel * add fused rotary_emb and kvcache memcopy * add fused_rotary_emb_and_cache_kernel.cu * add fused_rotary_emb_and_memcopy * fix bugs in fused_rotary_emb_and_cache_kernel.cu * fix ci bugs * use vec memcopy and opt the gloabl memory access * fix code style * fix test_rotary_embdding_unpad.py * codes revised based on the review comments * fix bugs about include path * rm inline
This commit is contained in:
98
extensions/csrc/common/vector_copy_utils.h
Normal file
98
extensions/csrc/common/vector_copy_utils.h
Normal file
@@ -0,0 +1,98 @@
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <cfloat>
|
||||
|
||||
#include "string"
|
||||
|
||||
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float *)dst) = *((float *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
|
||||
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||
*((float4 *)dst) = *((float4 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float *)dst) = *((float *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
|
||||
const c10::Half *src) {
|
||||
*((float4 *)dst) = *((float4 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
|
||||
*((float4 *)dst) = *((float4 *)src);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
|
||||
// Since the maximum memory alignment length is 128 bits, we choose float4
|
||||
// here.
|
||||
*((float4 *)dst) = *((float4 *)src);
|
||||
*((float4 *)(dst + 4)) = *((float4 *)(src + 4));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int get_vec_size(const torch::Tensor &tensor) {
|
||||
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
|
||||
const int max_aligned_size = 128;
|
||||
const int dtype_size = sizeof(T) * 8;
|
||||
|
||||
const int vec_size = max_aligned_size / sizeof(T) / 8;
|
||||
|
||||
if (address % (dtype_size * 4) == 0) {
|
||||
return std::min(4, vec_size);
|
||||
} else if (address % (dtype_size * 2) == 0) {
|
||||
return std::min(2, vec_size);
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
@@ -39,6 +39,9 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
|
||||
auto ins_shape = ins.sizes().vec();
|
||||
|
||||
ins_shape[0] = ins_shape[0]/2;
|
||||
if (ins_shape[0] == 1) {
|
||||
ins_shape.erase(ins_shape.begin());
|
||||
}
|
||||
auto outs = torch::zeros(ins_shape,ins.options());
|
||||
auto outs_shape = ins.sizes().vec();
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "../common/vector_copy_utils.h"
|
||||
#include "../common/micros.h"
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename scalar_t, int VecSize>
|
||||
__global__ void decode_kv_cache_memcpy_kernel(
|
||||
const scalar_t* __restrict__ key,
|
||||
const scalar_t* __restrict__ value,
|
||||
@@ -12,79 +12,146 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
||||
scalar_t* __restrict__ value_cache,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int block_size,
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride,
|
||||
const int block_table_stride
|
||||
)
|
||||
{
|
||||
const int seq_id = blockIdx.x;
|
||||
const int seq_len = sequence_lengths[seq_id] - 1;
|
||||
const int seq_id_in_block_table = seq_len / block_size;
|
||||
const int block_offset = seq_len % block_size;
|
||||
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
|
||||
const int hidden_size = num_heads * head_size;
|
||||
const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size];
|
||||
const int hidden_size = head_num * head_dim;
|
||||
|
||||
if ( block_id < 0 ) {
|
||||
return ;
|
||||
}
|
||||
|
||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
const int head_id = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int key_src_id = seq_id * key_stride + i;
|
||||
const int value_src_id = seq_id * value_stride + i;
|
||||
const int target_src_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_size
|
||||
+ block_offset * head_size + head_offset;
|
||||
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
|
||||
const int head_id = i / head_dim;
|
||||
const int head_offset = i % head_dim;
|
||||
const int64_t key_src_id = seq_id * key_stride + i;
|
||||
const int64_t value_src_id = seq_id * value_stride + i;
|
||||
const int64_t target_id = block_id * hidden_size * block_size
|
||||
+ head_id * block_size * head_dim
|
||||
+ block_offset * head_dim + head_offset;
|
||||
|
||||
key_cache[target_src_id] = key[key_src_id];
|
||||
value_cache[target_src_id] = value[value_src_id];
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void decode_kv_cache_memcpy(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
template<typename scalar_t>
|
||||
void apply_decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int head_num = key.size(1);
|
||||
int head_dim = key.size(2);
|
||||
int block_size = key_cache.size(2);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(key);
|
||||
|
||||
if (head_dim % vec_size != 0) {
|
||||
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||
vec_size = 1;
|
||||
}
|
||||
|
||||
int thread_nums = head_num * head_dim / vec_size;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
key.scalar_type(),
|
||||
"decode_kv_cache_memcpy",
|
||||
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
key_stride,
|
||||
value_stride,
|
||||
block_table_stride
|
||||
);)
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
head_num,
|
||||
head_dim,
|
||||
block_size,
|
||||
key_stride,
|
||||
value_stride,
|
||||
block_table_stride
|
||||
);
|
||||
break;
|
||||
case 2:
|
||||
decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
head_num,
|
||||
head_dim,
|
||||
block_size,
|
||||
key_stride,
|
||||
value_stride,
|
||||
block_table_stride
|
||||
);
|
||||
break;
|
||||
case 4:
|
||||
decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
head_num,
|
||||
head_dim,
|
||||
block_size,
|
||||
key_stride,
|
||||
value_stride,
|
||||
block_table_stride
|
||||
);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||
break;
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
}
|
||||
|
||||
void decode_kv_cache_memcpy(
|
||||
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
key.scalar_type(),
|
||||
"decode_kv_cache_memcpy",
|
||||
apply_decode_kv_cache_memcpy<scalar_t>(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
block_tables
|
||||
);)
|
||||
}
|
||||
|
472
extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Normal file
472
extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Normal file
@@ -0,0 +1,472 @@
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "../common/vector_copy_utils.h"
|
||||
#include "../common/micros.h"
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
__device__ void apply_emb_rotary_compute(
|
||||
scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr, const int64_t stride,
|
||||
const int token_id, const int shard_block_size, const int half_head_dim,
|
||||
const int head_num, const int head_dim) {
|
||||
scalar_t x[VecSize];
|
||||
scalar_t y[VecSize];
|
||||
scalar_t out_x[VecSize];
|
||||
scalar_t out_y[VecSize];
|
||||
|
||||
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
|
||||
i += blockDim.x * VecSize) {
|
||||
const int head_offset = i % half_head_dim;
|
||||
const int shard_offset =
|
||||
(head_offset / shard_block_size) * shard_block_size +
|
||||
(head_offset % shard_block_size) / VecSize;
|
||||
const int64_t addr_offset =
|
||||
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
|
||||
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
|
||||
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
__device__ void apply_kv_memcopy(
|
||||
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
|
||||
const int64_t stride, const int token_id, const int block_id,
|
||||
const int hidden_size, const int block_size, const int block_offset,
|
||||
const int head_dim, const int half_head_dim) {
|
||||
for (int i = threadIdx.x * VecSize; i < hidden_size / 2;
|
||||
i += blockDim.x * VecSize) {
|
||||
const int head_id = i / half_head_dim;
|
||||
const int head_offset = i % half_head_dim;
|
||||
const int64_t src_id = token_id * stride + head_id * head_dim + head_offset;
|
||||
const int64_t target_id = block_id * hidden_size * block_size +
|
||||
head_id * block_size * head_dim +
|
||||
block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
|
||||
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
|
||||
src + src_id + half_head_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
__device__ void cos_sin_memory_access(
|
||||
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
|
||||
scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
|
||||
const int shard_block_size, const int cos_stride, const int sin_stride,
|
||||
const int half_head_dim) {
|
||||
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
|
||||
// We assume that the value of head_dim is less than 128*128.
|
||||
const int shard_offset = (i % shard_block_size) / VecSize;
|
||||
const int shard_head =
|
||||
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
|
||||
cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
|
||||
sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int VecSize>
|
||||
__device__ void apply_k_rotary_emb_compute(
|
||||
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables, const int64_t key_stride,
|
||||
const int64_t value_stride, const int token_id,
|
||||
const int block_table_stride, const int head_num, const int head_dim,
|
||||
const int kv_head_num, const int block_size, const int half_head_dim,
|
||||
const int shard_block_size) {
|
||||
const int seq_len = sequence_lengths[token_id] - 1;
|
||||
const int block_offset = seq_len % block_size;
|
||||
const int block_id =
|
||||
block_tables[token_id * block_table_stride + seq_len / block_size];
|
||||
|
||||
if (block_id < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
scalar_t x[VecSize];
|
||||
scalar_t y[VecSize];
|
||||
scalar_t out_x[VecSize];
|
||||
scalar_t out_y[VecSize];
|
||||
|
||||
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
|
||||
i += blockDim.x * VecSize) {
|
||||
const int head_offset = i % half_head_dim;
|
||||
const int shard_offset =
|
||||
(head_offset / shard_block_size) * shard_block_size +
|
||||
(head_offset % shard_block_size) / VecSize;
|
||||
const int64_t addr_offset =
|
||||
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
|
||||
const int64_t target_id = block_id * head_num * head_dim * block_size +
|
||||
(i / half_head_dim) * block_size * head_dim +
|
||||
block_offset * head_dim + head_offset;
|
||||
|
||||
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
|
||||
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < VecSize; j++) {
|
||||
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||
}
|
||||
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
|
||||
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
|
||||
out_y);
|
||||
}
|
||||
|
||||
// apply value memcopy
|
||||
apply_kv_memcopy<scalar_t, VecSize>(
|
||||
value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
|
||||
block_size, block_offset, head_dim, half_head_dim);
|
||||
}
|
||||
|
||||
template<typename scalar_t, int VecSize>
|
||||
__global__ void rotary_embedding_and_cache_copy_kernel(
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ value,
|
||||
const scalar_t* __restrict__ cos,
|
||||
const scalar_t* __restrict__ sin,
|
||||
scalar_t* __restrict__ key_cache,
|
||||
scalar_t* __restrict__ value_cache,
|
||||
const int* __restrict__ sequence_lengths,
|
||||
const int* __restrict__ block_tables,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t value_stride,
|
||||
const int64_t half_shard_element_num,
|
||||
const int cos_stride,
|
||||
const int sin_stride,
|
||||
const int block_table_stride,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int kv_head_num,
|
||||
const int block_size
|
||||
) {
|
||||
|
||||
const int token_id = blockIdx.x;
|
||||
const int half_head_dim = head_dim / 2;
|
||||
const int shard_block_size = VecSize * 32;
|
||||
|
||||
extern __shared__ char shard_ptr[];
|
||||
|
||||
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
|
||||
// apply cos_sin memcopy
|
||||
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
__syncthreads();
|
||||
|
||||
//compute query
|
||||
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key and copy kv
|
||||
apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
|
||||
}
|
||||
|
||||
template<typename scalar_t, int VecSize>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
const scalar_t* __restrict__ cos,
|
||||
const scalar_t* __restrict__ sin,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int64_t half_shard_element_num,
|
||||
const int cos_stride,
|
||||
const int sin_stride,
|
||||
const int head_num,
|
||||
const int head_dim,
|
||||
const int kv_head_num
|
||||
) {
|
||||
const int token_id = blockIdx.x;
|
||||
const int half_head_dim = head_dim / 2;
|
||||
const int shard_block_size = VecSize * 32;
|
||||
|
||||
extern __shared__ char shard_ptr[];
|
||||
|
||||
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||
|
||||
// apply cos_sin memcopy
|
||||
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||
__syncthreads();
|
||||
|
||||
//compute query
|
||||
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||
|
||||
//compute key
|
||||
apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void apply_rotary_embedding_and_cache_copy(
|
||||
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
{
|
||||
int num_tokens = query.size(0);
|
||||
int head_num = query.size(1);
|
||||
int head_dim = query.size(2);
|
||||
int kv_head_num = key.size(1);
|
||||
int block_size = key_cache.size(2);
|
||||
|
||||
int64_t query_stride = query.stride(0);
|
||||
int64_t key_stride = key.stride(0);
|
||||
int64_t value_stride = value.stride(0);
|
||||
int cos_stride = cos.stride(0);
|
||||
int sin_stride = sin.stride(0);
|
||||
int block_table_stride = block_tables.stride(0);
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(query);
|
||||
|
||||
if ((head_dim / 2) % vec_size != 0) {
|
||||
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||
vec_size = 1;
|
||||
}
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int thread_nums = head_num * head_dim / vec_size / 2;
|
||||
const int shard_block_size = vec_size * 32 * 2;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
block_table_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num,
|
||||
block_size
|
||||
);
|
||||
break;
|
||||
case 2:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
block_table_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num,
|
||||
block_size
|
||||
);
|
||||
break;
|
||||
case 4:
|
||||
rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
sequence_lengths.data_ptr<int>(),
|
||||
block_tables.data_ptr<int>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
block_table_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num,
|
||||
block_size
|
||||
);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||
break;
|
||||
}
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void apply_rotary_embedding(
|
||||
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [total_tokens, head_dim]
|
||||
at::Tensor& sin // [total_tokens, head_dim]
|
||||
){
|
||||
int num_tokens = query.size(0);
|
||||
int head_num = query.size(1);
|
||||
int head_dim = query.size(2);
|
||||
int kv_head_num = key.size(1);
|
||||
|
||||
int query_stride = query.stride(0);
|
||||
int key_stride = key.stride(0);
|
||||
int cos_stride = cos.stride(0);
|
||||
int sin_stride = sin.stride(0);
|
||||
|
||||
int vec_size = get_vec_size<scalar_t>(query);
|
||||
|
||||
if ((head_dim / 2) % vec_size != 0) {
|
||||
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||
vec_size = 1;
|
||||
}
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int thread_nums = head_num * head_dim / vec_size / 2;
|
||||
const int shard_block_size = vec_size * 32 * 2;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(thread_nums, 512));
|
||||
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||
|
||||
switch (vec_size) {
|
||||
case 1:
|
||||
rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num
|
||||
);
|
||||
break;
|
||||
case 2:
|
||||
rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num
|
||||
);
|
||||
break;
|
||||
case 4:
|
||||
rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos.data_ptr<scalar_t>(),
|
||||
sin.data_ptr<scalar_t>(),
|
||||
query_stride,
|
||||
key_stride,
|
||||
shard_element_num / 2,
|
||||
cos_stride,
|
||||
sin_stride,
|
||||
head_num,
|
||||
head_dim,
|
||||
kv_head_num
|
||||
);
|
||||
break;
|
||||
default:
|
||||
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||
break;
|
||||
}
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void rotary_embedding_and_cache_copy(
|
||||
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [num_tokens, head_dim]
|
||||
at::Tensor& sin, // [num_tokens, head_dim]
|
||||
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||
at::Tensor& sequence_lengths, // [batch_size]
|
||||
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||
{
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_and_cache_copy",
|
||||
apply_rotary_embedding_and_cache_copy<scalar_t>(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cos,
|
||||
sin,
|
||||
key_cache,
|
||||
value_cache,
|
||||
sequence_lengths,
|
||||
block_tables
|
||||
);)
|
||||
}
|
||||
|
||||
void rotary_embedding(
|
||||
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
at::Tensor& cos, // [total_tokens, head_dim]
|
||||
at::Tensor& sin // [total_tokens, head_dim]
|
||||
){
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
apply_rotary_embedding<scalar_t>(
|
||||
query,
|
||||
key,
|
||||
cos,
|
||||
sin
|
||||
);)
|
||||
}
|
@@ -9,6 +9,23 @@ void decode_kv_cache_memcpy(
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||
torch::Tensor& cos, // [total_tokens, head_dim]
|
||||
torch::Tensor& sin); // [total_tokens, head_dim]
|
||||
|
||||
void rotary_embedding_and_cache_copy(
|
||||
torch::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
|
||||
torch::Tensor& cos, // [num_tokens, head_dim]
|
||||
torch::Tensor& sin, // [num_tokens, head_dim]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||
torch::Tensor& sequence_lengths, // [batch_size]
|
||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
||||
|
||||
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
|
||||
@@ -25,6 +42,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||
"Copy the GPU memory of kvcache during the decode stage.");
|
||||
|
||||
m.def(
|
||||
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
|
||||
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
|
||||
|
||||
m.def("rotary_embedding", &rotary_embedding,
|
||||
"performing Rotary Embedding-related calculations.");
|
||||
|
||||
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
||||
|
||||
m.def("rms_layernorm", &rms_layernorm,
|
||||
|
@@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
||||
for fname in [
|
||||
"cuda/pybind/inference.cpp",
|
||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||
"cuda/fused_rotary_emb_and_cache_kernel.cu",
|
||||
"cuda/activation_kernel.cu",
|
||||
"cuda/rms_layernorm_kernel.cu",
|
||||
]
|
||||
|
Reference in New Issue
Block a user