diff --git a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu
index 03682187e..19ea5bb8a 100644
--- a/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/kernel/cuda/decode_kv_cache_memcpy_kernel.cu
@@ -2,17 +2,21 @@
 #include <torch/extension.h>

 #include "utils/vec_copy.h"
+#include "funcs/cast_functor.h"
 #include "common/micros.h"

 using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
+using colossalAI::cuda::utils::copy;
+using colossalAI::funcs::CastFunctor;

-template<typename scalar_t, bool Aligned, int VecSize>
+
+template<typename T, typename CacheT, bool Aligned, int VecSize>
 __global__ void decode_kv_cache_memcpy_kernel(
-    const scalar_t* __restrict__ key,
-    const scalar_t* __restrict__ value,
-    scalar_t* __restrict__ key_cache,
-    scalar_t* __restrict__ value_cache,
+    const T* __restrict__ key,
+    const T* __restrict__ value,
+    CacheT* __restrict__ key_cache,
+    CacheT* __restrict__ value_cache,
     const int* __restrict__ sequence_lengths,
     const int* __restrict__ block_tables,
     const int head_num,
@@ -52,8 +56,8 @@ __global__ void decode_kv_cache_memcpy_kernel(
                                         + head_id * block_size * head_dim
                                         + block_offset * head_dim + head_offset;

-        copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
-        copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
+        copy<T, CacheT, VecSize>(key + key_src_id, key_cache + target_key_id);
+        copy<T, CacheT, VecSize>(value + value_src_id, value_cache + target_value_id);
     }

     if (!Aligned) {
@@ -73,14 +77,14 @@ __global__ void decode_kv_cache_memcpy_kernel(
                                             + head_id * block_size * head_dim
                                             + block_offset * head_dim + head_offset;

-            key_cache[target_key_id] = key[key_src_id];
-            value_cache[target_value_id] = value[value_src_id];
+            key_cache[target_key_id] = CastFunctor<T, CacheT>()(key[key_src_id]);
+            value_cache[target_value_id] = CastFunctor<T, CacheT>()(value[value_src_id]);
         }
     }
 }


-template<typename scalar_t>
+template<typename T, typename CacheT>
 void apply_decode_kv_cache_memcpy(
     at::Tensor& key,                 // [num_tokens, head_num, head_dim]
     at::Tensor& value,               // [num_tokens, head_num, head_dim]
@@ -99,7 +103,7 @@ void apply_decode_kv_cache_memcpy(
     int64_t value_stride = value.stride(0);
     int block_table_stride = block_tables.stride(0);

-    int vec_size = get_vec_size<scalar_t>(key);
+    int vec_size = get_vec_size<T>(key);

     bool aligned = true;
     if (head_dim % vec_size != 0) {
@@ -114,11 +118,11 @@ void apply_decode_kv_cache_memcpy(

 #define DECODE_KV_CACHE_MEMCOPY_KERNEL_LAUNCH(__aligned, __vec_size)                                 \
     do {                                                                                             \
-        decode_kv_cache_memcpy_kernel<scalar_t, __aligned, __vec_size><<<grid, block, 0, stream>>>(  \
-            key.data_ptr<scalar_t>(),                                                                \
-            value.data_ptr<scalar_t>(),                                                              \
-            key_cache.data_ptr<scalar_t>(),                                                          \
-            value_cache.data_ptr<scalar_t>(),                                                        \
+        decode_kv_cache_memcpy_kernel<T, CacheT, __aligned, __vec_size><<<grid, block, 0, stream>>>( \
+            reinterpret_cast<T*>(key.data_ptr()),                                                    \
+            reinterpret_cast<T*>(value.data_ptr()),                                                  \
+            reinterpret_cast<CacheT*>(key_cache.data_ptr()),                                         \
+            reinterpret_cast<CacheT*>(value_cache.data_ptr()),                                       \
             sequence_lengths.data_ptr<int>(),                                                        \
             block_tables.data_ptr<int>(),                                                            \
             head_num,                                                                                \
@@ -168,15 +172,46 @@ void decode_kv_cache_memcpy(
     at::Tensor& sequence_lengths,    // [batch_size]
     at::Tensor& block_tables)        // [batch_size, max_seq_len]
 {
-    DISPATCH_FLOAT_HALF_AND_BFLOAT(
-        key.scalar_type(),
-        "decode_kv_cache_memcpy",
-        apply_decode_kv_cache_memcpy<scalar_t>(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            sequence_lengths,
-            block_tables
-        );)
+
+#define _(T, CacheT)                          \
+    apply_decode_kv_cache_memcpy<T, CacheT>(  \
+        key,                                  \
+        value,                                \
+        key_cache,                            \
+        value_cache,                          \
+        sequence_lengths,                     \
+        block_tables                          \
+    )
+
+    if(key_cache.scalar_type() == at::ScalarType::Byte)
+    {
+        switch (key.scalar_type())
+        {
+            case at::ScalarType::Float:
+                _(float, uint8_t);
+                break;
+            case at::ScalarType::Half:
+                _(half, uint8_t);
+                break;
+            case at::ScalarType::BFloat16:
+                _(__nv_bfloat16, uint8_t);
+                break;
+        }
+    }
+    else
+    {
+        switch (key.scalar_type())
+        {
+            case at::ScalarType::Float:
+                _(float, float);
+                break;
+            case at::ScalarType::Half:
+                _(half, half);
+                break;
+            case at::ScalarType::BFloat16:
+                _(__nv_bfloat16, __nv_bfloat16);
+                break;
+        }
+    }
+#undef _
 }
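The rewritten kernel leans on two helpers this diff does not show: copy<T, CacheT, VecSize> from utils/vec_copy.h for the aligned, VecSize-wide path, and CastFunctor<T, CacheT> from funcs/cast_functor.h for the per-element unaligned fallback. The sketch below is only a minimal illustration of that cast-and-copy pattern, with hypothetical names (CastFunctorSketch, copy_sketch); the float-to-uint8_t placeholder is a plain static_cast and is NOT the encoding the real CastFunctor<T, uint8_t> specialization applies to the byte-typed KV cache.

// Minimal sketch, assuming hypothetical helper names; not the ColossalAI headers.
#include <cstdint>

template <typename From, typename To>
struct CastFunctorSketch {                      // hypothetical stand-in for CastFunctor
    __device__ To operator()(From v) const { return static_cast<To>(v); }
};

template <typename T>
struct CastFunctorSketch<T, T> {                // same source and cache dtype: no conversion
    __device__ T operator()(T v) const { return v; }
};

// Copies VecSize contiguous elements from dtype T into cache dtype CacheT,
// casting each element on the way. A scalar loop is used here for clarity;
// the real copy<> presumably vectorizes at least the T == CacheT case.
template <typename T, typename CacheT, int VecSize>
__device__ void copy_sketch(const T* src, CacheT* dst) {
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
        dst[i] = CastFunctorSketch<T, CacheT>()(src[i]);
    }
}

This mirrors the two paths in the kernel above: the aligned loop hands whole VecSize-wide chunks to copy<T, CacheT, VecSize>, while the head_dim % vec_size != 0 fallback converts one element at a time through CastFunctor<T, CacheT>().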