[Inference] Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)
* Support FP16/BF16 Flash Attention 2
* fix bugs in test_kv_cache_memcpy.py
* add context_kv_cache_memcpy_kernel.cu
* rm typename MT
* add tail process
* add high_precision
* add high_precision to config.py
* rm unused code
* change the comment for the high_precision parameter
* update test_rotary_embdding_unpad.py
* fix vector_copy_utils.h
* add comment for self.high_precision when using float32
@@ -56,6 +56,23 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");         \
   }
 
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION,  \
+                                                           TYPE, NAME, ...) \
+  switch (HIGH_PRECISION) {                                                 \
+    case false: {                                                           \
+      const bool high_precision = false;                                    \
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);              \
+      break;                                                                \
+    }                                                                       \
+    case true: {                                                            \
+      const bool high_precision = true;                                     \
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__);              \
+      break;                                                                \
+    }                                                                       \
+    default:                                                                \
+      AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
+  }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
   switch (TYPEIN) {                                                            \
     case at::ScalarType::Float: {                                              \
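The new dispatch macro turns the runtime high_precision flag into the compile-time constant high_precision, which the dispatched body can then use as a template argument. A minimal caller sketch (illustrative only; my_op and my_op_impl are hypothetical names, not part of this commit):

#include <torch/extension.h>

// Hypothetical implementation dispatched on both dtype and precision flag.
template <typename scalar_t, bool high_precision>
void my_op_impl(at::Tensor& x);

void my_op(at::Tensor& x, bool use_high_precision) {
  // Inside __VA_ARGS__, both `scalar_t` and `high_precision` are
  // compile-time entities selected from the runtime arguments.
  DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
      use_high_precision, x.scalar_type(), "my_op",
      my_op_impl<scalar_t, high_precision>(x););
}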
@@ -27,5 +27,18 @@ struct MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
 
+template <bool high_precision, typename scalar_t>
+struct ScalarTypeTrait;
+
+template <typename T>
+struct ScalarTypeTrait<true, T> {
+  using Type = typename MPTypeTrait<T>::Type;
+};
+
+template <typename T>
+struct ScalarTypeTrait<false, T> {
+  using Type = T;
+};
+
 }  // namespace common
 }  // namespace colossalAI
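An illustrative check of how the new trait resolves (a sketch, assuming the MPTypeTrait specializations above map at::Half and at::BFloat16 to float):

#include <type_traits>

using colossalAI::common::ScalarTypeTrait;

// high_precision = true promotes to the MPTypeTrait compute type ...
static_assert(std::is_same<ScalarTypeTrait<true, at::Half>::Type, float>::value, "");
// ... while high_precision = false keeps the storage type.
static_assert(std::is_same<ScalarTypeTrait<false, at::Half>::Type, at::Half>::value, "");
// float is already the compute type, so it is unaffected either way.
static_assert(std::is_same<ScalarTypeTrait<true, float>::Type, float>::value, "");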
extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu (new file, 195 lines)
@@ -0,0 +1,195 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "utils/vector_copy_utils.h"
#include "../common/micros.h"

template<typename scalar_t, int VecSize>
__global__ void context_kv_cache_memcpy_kernel(
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache,
    scalar_t* __restrict__ value_cache,
    const int* __restrict__ sequence_lengths,
    const int* __restrict__ cu_seqlens,
    const int* __restrict__ block_tables,
    const int head_num,
    const int head_dim,
    const int block_size,
    const int batch_size,
    const int block_table_stride,
    const int64_t key_stride,
    const int64_t value_stride
)
{
    const int seq_token_id = blockIdx.x;
    const int seq_id = blockIdx.y;
    const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];

    if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
        return ;
    }

    const int block_offset = seq_token_id % block_size;
    const int hidden_size = head_num * head_dim;
    const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
    int head_id;
    int head_offset;
    int64_t key_src_id;
    int64_t value_src_id;
    int64_t target_id;

    int i = threadIdx.x * VecSize;

    for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
        head_id = i / head_dim;
        head_offset = i % head_dim;
        key_src_id = total_token_id * key_stride + i;
        value_src_id = total_token_id * value_stride + i;
        target_id = block_id * hidden_size * block_size
                    + head_id * block_size * head_dim
                    + block_offset * head_dim + head_offset;

        copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
        copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
    }

    // tail process
    for (; i < hidden_size; ++i ) {
        head_id = i / head_dim;
        head_offset = i % head_dim;
        key_src_id = total_token_id * key_stride + i;
        value_src_id = total_token_id * value_stride + i;
        target_id = block_id * hidden_size * block_size
                    + head_id * block_size * head_dim
                    + block_offset * head_dim + head_offset;

        key_cache[target_id] = key[key_src_id];
        value_cache[target_id] = value[value_src_id];
    }

}

template<typename scalar_t>
void apply_context_kv_cache_memcpy(
    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
    at::Tensor& value,               // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,    // [batch_size]
    at::Tensor& cu_seqlens,          // [batch_size + 1]
    at::Tensor& block_tables,        // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    int num_tokens = key.size(0);
    int head_num = key.size(1);
    int head_dim = key.size(2);
    int block_size = key_cache.size(2);
    int batch_size = block_tables.size(0);

    int64_t key_stride = key.stride(0);
    int64_t value_stride = value.stride(0);
    int block_table_stride = block_tables.stride(0);

    int vec_size = get_vec_size<scalar_t>(key);

    if (head_dim % vec_size != 0) {
        // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
        vec_size = 1;
    }

    int thread_nums = head_num * head_dim / vec_size;

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    dim3 grid(max_seq_len_in_batch, batch_size);
    dim3 block(std::min(thread_nums, 512));

    switch (vec_size) {
        case 1:
            context_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
                key.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>(),
                key_cache.data_ptr<scalar_t>(),
                value_cache.data_ptr<scalar_t>(),
                sequence_lengths.data_ptr<int>(),
                cu_seqlens.data_ptr<int>(),
                block_tables.data_ptr<int>(),
                head_num,
                head_dim,
                block_size,
                batch_size,
                block_table_stride,
                key_stride,
                value_stride
            );
            break;
        case 2:
            context_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
                key.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>(),
                key_cache.data_ptr<scalar_t>(),
                value_cache.data_ptr<scalar_t>(),
                sequence_lengths.data_ptr<int>(),
                cu_seqlens.data_ptr<int>(),
                block_tables.data_ptr<int>(),
                head_num,
                head_dim,
                block_size,
                batch_size,
                block_table_stride,
                key_stride,
                value_stride
            );
            break;
        case 4:
            context_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
                key.data_ptr<scalar_t>(),
                value.data_ptr<scalar_t>(),
                key_cache.data_ptr<scalar_t>(),
                value_cache.data_ptr<scalar_t>(),
                sequence_lengths.data_ptr<int>(),
                cu_seqlens.data_ptr<int>(),
                block_tables.data_ptr<int>(),
                head_num,
                head_dim,
                block_size,
                batch_size,
                block_table_stride,
                key_stride,
                value_stride
            );
            break;
        default:
            AT_ERROR("Unsupported vectorized size ", vec_size);
            break;
    }

    AT_CUDA_CHECK(cudaGetLastError());

}

void context_kv_cache_memcpy(
    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
    at::Tensor& value,               // [num_tokens, head_num, head_dim]
    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
    at::Tensor& sequence_lengths,    // [batch_size]
    at::Tensor& cu_seqlens,          // [batch_size + 1]
    at::Tensor& block_tables,        // [batch_size, max_seq_len]
    int max_seq_len_in_batch)
{
    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        key.scalar_type(),
        "context_kv_cache_memcpy",
        apply_context_kv_cache_memcpy<scalar_t>(
            key,
            value,
            key_cache,
            value_cache,
            sequence_lengths,
            cu_seqlens,
            block_tables,
            max_seq_len_in_batch
        );)
}
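A worked example of the cache indexing in the kernel above (illustrative numbers only). The caches are laid out as [num_blocks, head_num, block_size, head_dim], so:

// head_num = 2, head_dim = 4, block_size = 16, block_offset = 3:
// element i = 5 of a token splits into head_id = 5 / 4 = 1 and
// head_offset = 5 % 4 = 1, and is written to
//   target_id = block_id * (2 * 4) * 16   // skip earlier blocks
//             + 1 * 16 * 4                // skip head 0 inside this block
//             + 3 * 4 + 1;                // token slot within block, then element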
@@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel(
         return ;
     }
 
-    for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
+    int i = threadIdx.x * VecSize;
+
+    for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
         const int head_id = i / head_dim;
         const int head_offset = i % head_dim;
         const int64_t key_src_id = seq_id * key_stride + i;
@@ -43,6 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel(
         copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
     }
 
+    for (; i < hidden_size; ++i ) {
+        const int head_id = i / head_dim;
+        const int head_offset = i % head_dim;
+        const int64_t key_src_id = seq_id * key_stride + i;
+        const int64_t value_src_id = seq_id * value_stride + i;
+        const int64_t target_id = block_id * hidden_size * block_size
+                                  + head_id * block_size * head_dim
+                                  + block_offset * head_dim + head_offset;
+
+        key_cache[target_id] = key[key_src_id];
+        value_cache[target_id] = value[value_src_id];
+    }
+
 }
 
 template<typename scalar_t>
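The change above splits the copy into a vectorized body plus a scalar tail, so hidden_size no longer has to be a multiple of VecSize. A minimal single-threaded sketch of the same pattern (illustrative; copy_with_tail is a hypothetical helper, not part of this commit):

template <int VecSize>
void copy_with_tail(const float* src, float* dst, int n) {
  int i = 0;
  // Vectorized body: covers the largest prefix divisible by VecSize.
  for (; i <= n - VecSize; i += VecSize) {
    for (int j = 0; j < VecSize; ++j) dst[i + j] = src[i + j];
  }
  // Scalar tail: finishes the remaining n % VecSize elements.
  for (; i < n; ++i) dst[i] = src[i];
}
// e.g. n = 10, VecSize = 4: the body covers i = 0..7, the tail covers i = 8, 9.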
@@ -1,14 +1,15 @@
-
+// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
 #include "utils/vector_copy_utils.h"
 #include "../common/micros.h"
+#include "../common/mp_type_traits.h"
 
-template <typename scalar_t, int VecSize>
+template <typename scalar_t, typename m_scalar_t, int VecSize>
 __device__ void apply_emb_rotary_compute(
-    scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
-    const scalar_t* __restrict__ sin_ptr, const int64_t stride,
+    scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
+    const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
     const int token_id, const int shard_block_size, const int half_head_dim,
     const int head_num, const int head_dim) {
   scalar_t x[VecSize];
@@ -30,10 +31,10 @@ __device__ void apply_emb_rotary_compute(
 
 #pragma unroll
   for (int j = 0; j < VecSize; j++) {
-    out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
-               y[j] * sin_ptr[j * 32 + shard_offset];
-    out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
-               x[j] * sin_ptr[j * 32 + shard_offset];
+    out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
+                                     static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
+    out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
+                                     static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
   }
 
   copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
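The rewritten rotation loads values in the storage type scalar_t, does the multiply-add in the compute type m_scalar_t (float when high_precision is set), and casts back before storing. A scalar sketch of the same idea (illustrative only; rotate_pair is a hypothetical helper):

template <typename scalar_t, typename m_scalar_t>
__device__ void rotate_pair(scalar_t& x, scalar_t& y, m_scalar_t c, m_scalar_t s) {
  const m_scalar_t xf = static_cast<m_scalar_t>(x);
  const m_scalar_t yf = static_cast<m_scalar_t>(y);
  x = static_cast<scalar_t>(xf * c - yf * s);  // rotated "real" half
  y = static_cast<scalar_t>(yf * c + xf * s);  // rotated "imag" half
}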
@@ -62,10 +63,10 @@ __device__ void apply_kv_memcopy(
   }
 }
 
-template <typename scalar_t, int VecSize>
+template <typename scalar_t, typename m_scalar_t, int VecSize>
 __device__ void cos_sin_memory_access(
     const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
-    scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
+    m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
     const int shard_block_size, const int cos_stride, const int sin_stride,
     const int half_head_dim) {
   for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
@@ -73,16 +74,16 @@ __device__ void cos_sin_memory_access(
     const int shard_offset = (i % shard_block_size) / VecSize;
     const int shard_head =
         (i / shard_block_size) * shard_block_size + i % VecSize * 32;
-    cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
-    sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
+    cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
+    sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
   }
 }
 
-template <typename scalar_t, int VecSize>
+template <typename scalar_t, typename m_scalar_t, int VecSize>
 __device__ void apply_k_rotary_emb_compute(
     scalar_t* __restrict__ key, scalar_t* __restrict__ value,
     scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
-    const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
+    const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
     const int* __restrict__ sequence_lengths,
     const int* __restrict__ block_tables, const int64_t key_stride,
     const int64_t value_stride, const int token_id,
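A worked example of the staging index used above (illustrative numbers; the actual shard_block_size and VecSize depend on the launch configuration):

// VecSize = 4, shard_block_size = 128, i = 130:
//   shard_offset = (130 % 128) / 4                     = 0
//   shard_head   = (130 / 128) * 128 + (130 % 4) * 32  = 128 + 64 = 192
// so cos[token_id * cos_stride + 130] is staged at cos_ptr[192 + 0],
// now converted to m_scalar_t by the static_cast added in this change.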
@@ -120,10 +121,10 @@ __device__ void apply_k_rotary_emb_compute(
 
 #pragma unroll
   for (int j = 0; j < VecSize; j++) {
-    out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
-               y[j] * sin_ptr[j * 32 + shard_offset];
-    out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
-               x[j] * sin_ptr[j * 32 + shard_offset];
+    out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
+                                     static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
+    out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
+                                     static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
   }
 
   copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
@@ -137,7 +138,7 @@ __device__ void apply_k_rotary_emb_compute(
                      block_size, block_offset, head_dim, half_head_dim);
 }
 
-template<typename scalar_t, int VecSize>
+template<typename scalar_t, typename m_scalar_t, int VecSize>
 __global__ void rotary_embedding_and_cache_copy_kernel(
     scalar_t* __restrict__ query,
     scalar_t* __restrict__ key,
@@ -167,21 +168,21 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
 
     extern __shared__ char shard_ptr[];
 
-    scalar_t *cos_ptr = (scalar_t*)shard_ptr;
-    scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
+    m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
+    m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
 
     // apply cos_sin memcopy
-    cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
+    cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
     __syncthreads();
 
     //compute query
-    apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
+    apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
 
     //compute key and copy kv
-    apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
+    apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
 }
 
-template<typename scalar_t, int VecSize>
+template<typename scalar_t, typename m_scalar_t, int VecSize>
 __global__ void rotary_embedding_kernel(
     scalar_t* __restrict__ query,
     scalar_t* __restrict__ key,
@@ -202,21 +203,21 @@ __global__ void rotary_embedding_kernel(
 
     extern __shared__ char shard_ptr[];
 
-    scalar_t *cos_ptr = (scalar_t*)shard_ptr;
-    scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
+    m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
+    m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
 
     // apply cos_sin memcopy
-    cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
+    cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
     __syncthreads();
 
     //compute query
-    apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
+    apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
 
     //compute key
-    apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
+    apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
 }
 
-template<typename scalar_t>
+template<typename scalar_t, bool high_precision>
 void apply_rotary_embedding_and_cache_copy(
     at::Tensor& query,               // [num_tokens, head_num, head_dim]
     at::Tensor& key,                 // [num_tokens, kv_head_num, head_dim]
@@ -241,6 +242,8 @@ void apply_rotary_embedding_and_cache_copy(
     int sin_stride = sin.stride(0);
     int block_table_stride = block_tables.stride(0);
 
+    using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
+
     int vec_size = get_vec_size<scalar_t>(query);
 
     if ((head_dim / 2) % vec_size != 0) {
@@ -259,7 +262,7 @@ void apply_rotary_embedding_and_cache_copy(
 
     switch (vec_size) {
         case 1:
-            rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
+            rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
                 query.data_ptr<scalar_t>(),
                 key.data_ptr<scalar_t>(),
                 value.data_ptr<scalar_t>(),
@@ -283,7 +286,7 @@
             );
             break;
         case 2:
-            rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
+            rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
                 query.data_ptr<scalar_t>(),
                 key.data_ptr<scalar_t>(),
                 value.data_ptr<scalar_t>(),
@@ -307,7 +310,7 @@
             );
             break;
         case 4:
-            rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
+            rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
                 query.data_ptr<scalar_t>(),
                 key.data_ptr<scalar_t>(),
                 value.data_ptr<scalar_t>(),
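Note on the launch changes above: the dynamic shared memory that stages the cos/sin values must now be sized with m_scalar_t rather than scalar_t. An illustrative comparison (shard_element_num = 512 is a made-up value):

// scalar_t = at::Half, shard_element_num = 512:
//   high_precision = false: m_scalar_t = at::Half -> 512 * 2 bytes = 1024 bytes
//   high_precision = true : m_scalar_t = float    -> 512 * 4 bytes = 2048 bytes
// scalar_t = float: m_scalar_t is float either way, so the size is unchanged.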
@@ -338,12 +341,12 @@
     AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template<typename scalar_t>
+template<typename scalar_t, bool high_precision>
 void apply_rotary_embedding(
     at::Tensor& query,   // [total_tokens, head_num, head_dim]
     at::Tensor& key,     // [total_tokens, kv_head_num, head_dim]
     at::Tensor& cos,     // [total_tokens, head_dim]
     at::Tensor& sin      // [total_tokens, head_dim]
 ){
     int num_tokens = query.size(0);
     int head_num = query.size(1);
@@ -355,6 +358,8 @@ void apply_rotary_embedding(
     int cos_stride = cos.stride(0);
     int sin_stride = sin.stride(0);
 
+    using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
+
     int vec_size = get_vec_size<scalar_t>(query);
 
     if ((head_dim / 2) % vec_size != 0) {
@@ -373,7 +378,7 @@ void apply_rotary_embedding(
 
     switch (vec_size) {
         case 1:
-            rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
+            rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
                 query.data_ptr<scalar_t>(),
                 key.data_ptr<scalar_t>(),
                 cos.data_ptr<scalar_t>(),
@@ -389,7 +394,7 @@
             );
             break;
         case 2:
-            rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
+            rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
                 query.data_ptr<scalar_t>(),
                 key.data_ptr<scalar_t>(),
                 cos.data_ptr<scalar_t>(),
@@ -405,7 +410,7 @@
             );
             break;
         case 4:
-            rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
+            rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t), stream>>>(
                 query.data_ptr<scalar_t>(),
                 key.data_ptr<scalar_t>(),
                 cos.data_ptr<scalar_t>(),
@@ -436,12 +441,14 @@ void rotary_embedding_and_cache_copy(
     at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
     at::Tensor& sequence_lengths,    // [batch_size]
-    at::Tensor& block_tables)        // [batch_size, max_seq_len]
+    at::Tensor& block_tables,        // [batch_size, max_seq_len]
+    bool high_precision)
 {
-    DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
+        high_precision,
         query.scalar_type(),
         "rotary_embedding_and_cache_copy",
-        apply_rotary_embedding_and_cache_copy<scalar_t>(
+        apply_rotary_embedding_and_cache_copy<scalar_t, high_precision>(
            query,
            key,
            value,
@@ -458,12 +465,14 @@ void rotary_embedding(
     at::Tensor& query,   // [total_tokens, head_num, head_dim]
     at::Tensor& key,     // [total_tokens, kv_head_num, head_dim]
     at::Tensor& cos,     // [total_tokens, head_dim]
-    at::Tensor& sin      // [total_tokens, head_dim]
+    at::Tensor& sin,     // [total_tokens, head_dim]
+    bool high_precision
 ){
-    DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
+        high_precision,
         query.scalar_type(),
         "rotary_embedding",
-        apply_rotary_embedding<scalar_t>(
+        apply_rotary_embedding<scalar_t, high_precision>(
            query,
            key,
            cos,
@@ -9,11 +9,22 @@
     torch::Tensor& sequence_lengths,  // [batch_size]
     torch::Tensor& block_tables);     // [batch_size, max_seq_len]
 
+void context_kv_cache_memcpy(
+    at::Tensor& key,                 // [num_tokens, head_num, head_dim]
+    at::Tensor& value,               // [num_tokens, head_num, head_dim]
+    at::Tensor& key_cache,           // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& value_cache,         // [num_blocks, head_num, block_size, head_dim]
+    at::Tensor& sequence_lengths,    // [batch_size]
+    at::Tensor& cu_seqlens,          // [batch_size + 1]
+    at::Tensor& block_tables,        // [batch_size, max_seq_len]
+    int max_seq_len_in_batch);
+
 void rotary_embedding(
     torch::Tensor& query,  // [total_tokens, head_num, head_dim]
     torch::Tensor& key,    // [total_tokens, kv_head_num, head_dim]
     torch::Tensor& cos,    // [total_tokens, head_dim]
-    torch::Tensor& sin);   // [total_tokens, head_dim]
+    torch::Tensor& sin,    // [total_tokens, head_dim]
+    bool high_precision);
 
 void rotary_embedding_and_cache_copy(
     torch::Tensor& query,  // [num_tokens, head_num, head_dim]
@@ -25,7 +36,9 @@ void rotary_embedding_and_cache_copy(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, block_size, head_dim]
     torch::Tensor& sequence_lengths,  // [batch_size]
-    torch::Tensor& block_tables);     // [batch_size, max_seq_len]
+    torch::Tensor& block_tables,      // [batch_size, max_seq_len]
+    bool high_precision);
 
 torch::Tensor silu_and_mul(const torch::Tensor& ins);
 
 void rms_layernorm(torch::Tensor& out,  // [..., hidden_size]
@@ -42,6 +55,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
         "Copy the GPU memory of kvcache during the decode stage.");
 
+  m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy,
+        "Copy the GPU memory of kvcache during the context stage.");
+
   m.def(
       "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
       "performing Rotary Embedding-related calculations and KVCache Memcopy.");
@@ -11,6 +11,8 @@
 #include <cfloat>
 #include <limits>
 
+#include "utils/vector_copy_utils.h"
+
 namespace {
 
 int log2_ceil(int value) {
@@ -11,16 +11,16 @@ template <typename T, int VecSize>
 __device__ __inline__ void copy_vector(T *dst, const T *src) {
   using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
   // Note(LiuYang): Here static_cast can't be used for cast between two pointer
-  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<VT *>(src));
+  *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }
 
 template <>
 __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
   // Since the maximum memory alignment length is 128 bits, we choose float4
   // here.
-  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<float4 *>(src));
+  *(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
   *(reinterpret_cast<float4 *>(dst + 4)) =
-      *(reinterpret_cast<float4 *>(src + 4));
+      *(reinterpret_cast<const float4 *>(src + 4));
 }
 
 template <typename T, int VecSize>
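The const qualifiers added above matter because reinterpret_cast cannot cast away constness. A sketch of the distinction (illustrative only):

__device__ void example(const float* src, float* dst) {
  // *(reinterpret_cast<float4 *>(src));           // ill-formed: would drop const
  *(reinterpret_cast<float4 *>(dst)) =
      *(reinterpret_cast<const float4 *>(src));    // OK: const is preserved
}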
@@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
         for fname in [
             "cuda/pybind/inference.cpp",
             "cuda/decode_kv_cache_memcpy_kernel.cu",
+            "cuda/context_kv_cache_memcpy_kernel.cu",
             "cuda/fused_rotary_emb_and_cache_kernel.cu",
             "cuda/activation_kernel.cu",
             "cuda/rms_layernorm_kernel.cu",