From ef8e4ffe310bfe21f83feb965d962d816d75bc88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?=
Date: Tue, 30 Apr 2024 18:33:53 +0800
Subject: [PATCH] [Inference/Feat] Add kvcache quant support for
 fused_rotary_embedding_cache_copy (#5680)

---
 extensions/csrc/common/mp_type_traits.h       |  17 +
 extensions/csrc/funcs/binary_functor.h        |  19 ++
 extensions/csrc/funcs/cast_functor.h          |   4 +
 .../cuda/context_kv_cache_memcpy_kernel.cu    |   6 -
 .../cuda/flash_decoding_attention_kernel.cu   |   6 -
 .../cuda/fused_rotary_emb_and_cache_kernel.cu | 294 +++++++++++-------
 extensions/csrc/kernel/cuda/utils/vec_copy.h  |   5 +-
 7 files changed, 226 insertions(+), 125 deletions(-)

diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h
index 527573219..7a27f2650 100644
--- a/extensions/csrc/common/mp_type_traits.h
+++ b/extensions/csrc/common/mp_type_traits.h
@@ -4,6 +4,11 @@
 
 #include "micros.h"
 
+#if defined(COLOSSAL_WITH_CUDA)
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#endif
+
 namespace colossalAI {
 namespace common {
 
@@ -27,6 +32,18 @@ struct MPTypeTrait<at::BFloat16> {
   using Type = float;
 };
 
+#if defined(COLOSSAL_WITH_CUDA)
+template <>
+struct MPTypeTrait<half> {
+  using Type = float;
+};
+
+template <>
+struct MPTypeTrait<__nv_bfloat16> {
+  using Type = float;
+};
+#endif
+
 template <bool high_precision, typename T>
 struct ScalarTypeTrait {
   using Type =
diff --git a/extensions/csrc/funcs/binary_functor.h b/extensions/csrc/funcs/binary_functor.h
index c5fe48076..822f131c2 100644
--- a/extensions/csrc/funcs/binary_functor.h
+++ b/extensions/csrc/funcs/binary_functor.h
@@ -56,6 +56,11 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
                                        typename T)
 
 #if defined(COLOSSAL_WITH_CUDA)
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMinus,
+                                       DEVICE, STMTS_WRAPPER({
+                                         return __hsub(lhs, rhs);
+                                       }))
+
 COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
                                        DEVICE, STMTS_WRAPPER({
                                          return __hadd(lhs, rhs);
@@ -71,6 +76,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                        DEVICE, STMTS_WRAPPER({
                                          return __hadd(lhs, rhs);
                                        }))
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
+                                       __nv_bfloat16, BinaryOpType::kMinus,
+                                       DEVICE, STMTS_WRAPPER({
+                                         return __hsub(lhs, rhs);
+                                       }))
+
 COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                        __nv_bfloat162, BinaryOpType::kAdd,
                                        DEVICE, STMTS_WRAPPER({
@@ -82,6 +94,13 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
     STMTS_WRAPPER({
       return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
     }))
+
+COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
+    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMinus, DEVICE,
+    STMTS_WRAPPER({
+      return __float2bfloat16(__bfloat162float(lhs) - __bfloat162float(rhs));
+    }))
+
 COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
     __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
     STMTS_WRAPPER({
diff --git a/extensions/csrc/funcs/cast_functor.h b/extensions/csrc/funcs/cast_functor.h
index d9691d870..6382d5271 100644
--- a/extensions/csrc/funcs/cast_functor.h
+++ b/extensions/csrc/funcs/cast_functor.h
@@ -94,6 +94,10 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, __nv_bfloat16, DEVICE,
                                      STMTS_WRAPPER({
                                        return __float2bfloat16_rn(val);
                                      }))
+COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, float, DEVICE,
+                                     STMTS_WRAPPER({
+                                       return __bfloat162float(val);
+                                     }))
 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
                                      STMTS_WRAPPER({
                                        dtype::bfloat164 dst;
diff --git a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
index 473324f45..e9b7738b0 100644
--- a/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
+++ b/extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
@@ -192,12 +192,6 @@ void context_kv_cache_memcpy(
     int max_seq_len_in_batch)
 {
 
-    TORCH_CHECK(key.scalar_type() == at::ScalarType::Float || key.scalar_type() == at::ScalarType::Half || key.scalar_type() == at::ScalarType::BFloat16,
-                "Dtype of key should be float, half or bfloat16!");
-    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == key.scalar_type(),
-                "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
-
-
 #define _(T, CacheT)                              \
     apply_context_kv_cache_memcpy<T, CacheT>(     \
         key,                                      \
diff --git a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
index 110907435..bcea786fe 100644
--- a/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
+++ b/extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
@@ -380,12 +380,6 @@ void flash_decoding_attention(
     const c10::optional<torch::Tensor>& alibi_slopes,
     float scale) {
-
-  TORCH_CHECK(query.scalar_type() == at::ScalarType::Float || query.scalar_type() == at::ScalarType::Half || query.scalar_type() == at::ScalarType::BFloat16,
-              "Dtype of query should be float, half or bfloat16!");
-  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte || key_cache.scalar_type() == query.scalar_type(),
-              "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
-
   if(key_cache.scalar_type() == at::ScalarType::Byte)
   {
     switch (query.scalar_type()) {
diff --git a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
index 7a2629171..68b47c7e9 100644
--- a/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
+++ b/extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
@@ -5,20 +5,30 @@
 #include "utils/vec_copy.h"
 #include "common/micros.h"
 #include "common/mp_type_traits.h"
+#include "funcs/cast_functor.h"
+#include "funcs/binary_functor.h"
 
 using colossalAI::cuda::utils::copy_vector;
 using colossalAI::cuda::utils::get_vec_size;
+using colossalAI::cuda::utils::copy;
+using colossalAI::funcs::CastFunctor;
+using colossalAI::funcs::BinaryOpFunctor;
+using colossalAI::funcs::BinaryOpType;
 
-template <typename scalar_t, typename m_scalar_t, int VecSize>
+template <typename T, typename MT, int VecSize>
 __device__ void apply_emb_rotary_compute(
-    scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr,
-    const m_scalar_t* __restrict__ sin_ptr, const int64_t stride,
+    T* __restrict__ src, const MT* __restrict__ cos_ptr,
+    const MT* __restrict__ sin_ptr, const int64_t stride,
     const int token_id, const int shard_block_size, const int half_head_dim,
     const int head_num, const int head_dim) {
-  scalar_t x[VecSize];
-  scalar_t y[VecSize];
-  scalar_t out_x[VecSize];
-  scalar_t out_y[VecSize];
+  BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
+  BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
+  BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
+
+  T x[VecSize];
+  T y[VecSize];
+  T out_x[VecSize];
+  T out_y[VecSize];
 
   for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
        i += blockDim.x * VecSize) {
@@ -29,25 +39,25 @@ __device__ void apply_emb_rotary_compute(
     const int64_t addr_offset =
         token_id * stride + (i / half_head_dim) * head_dim + head_offset;
 
-    copy_vector<scalar_t, VecSize>(x, src + addr_offset);
-    copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
+    copy<T, VecSize>(src + addr_offset, x);
+    copy<T, VecSize>(src + addr_offset + half_head_dim, y);
 
 #pragma unroll
     for (int j = 0; j < VecSize; j++) {
-      out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x[j]) * cos_ptr[j * 32 + shard_offset] -
-                                       static_cast<m_scalar_t>(y[j]) * sin_ptr[j * 32 + shard_offset]);
-      out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(y[j]) * cos_ptr[j * 32 + shard_offset] +
-                                       static_cast<m_scalar_t>(x[j]) * sin_ptr[j * 32 + shard_offset]);
+      out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x[j]), cos_ptr[j * 32 + shard_offset]),
+                                          mul(CastFunctor<T, MT>()(y[j]), sin_ptr[j * 32 + shard_offset])));
+      out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(y[j]), cos_ptr[j * 32 + shard_offset]),
+                                          mul(CastFunctor<T, MT>()(x[j]), sin_ptr[j * 32 + shard_offset])));
     }
 
-    copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
-    copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
+    copy<T, VecSize>(out_x, src + addr_offset);
+    copy<T, VecSize>(out_y, src + addr_offset + half_head_dim);
   }
 }
 
-template <typename scalar_t, int VecSize>
+template <typename T, typename CacheT, int VecSize>
 __device__ void apply_kv_memcopy(
-    scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
+    T* __restrict__ src, CacheT* __restrict__ cache,
     const int64_t stride, const int token_id, const int block_id,
     const int hidden_size, const int block_size, const int block_offset,
     const int head_dim, const int half_head_dim) {
@@ -60,16 +70,15 @@ __device__ void apply_kv_memcopy(
                             head_id * block_size * head_dim +
                             block_offset * head_dim + head_offset;
 
-    copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
-    copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
-                                   src + src_id + half_head_dim);
+    copy<T, CacheT, VecSize>(src + src_id, cache + target_id);
+    copy<T, CacheT, VecSize>(src + src_id + half_head_dim, cache + target_id + half_head_dim);
   }
 }
 
-template <typename scalar_t, typename m_scalar_t, int VecSize>
+template <typename T, typename MT, int VecSize>
 __device__ void cos_sin_memory_access(
-    const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
-    m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id,
+    const T* __restrict__ cos, const T* __restrict__ sin,
+    MT* cos_ptr, MT* sin_ptr, const int token_id,
    const int shard_block_size, const int cos_stride, const int sin_stride,
    const int half_head_dim) {
  for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
@@ -77,22 +86,26 @@ __device__ void cos_sin_memory_access(
     const int shard_offset = (i % shard_block_size) / VecSize;
     const int shard_head =
         (i / shard_block_size) * shard_block_size + i % VecSize * 32;
-    cos_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(cos[token_id * cos_stride + i]);
-    sin_ptr[shard_head + shard_offset] = static_cast<m_scalar_t>(sin[token_id * sin_stride + i]);
+    cos_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(cos[token_id * cos_stride + i]);
+    sin_ptr[shard_head + shard_offset] = CastFunctor<T, MT>()(sin[token_id * sin_stride + i]);
   }
 }
 
-template <typename scalar_t, typename m_scalar_t, int VecSize>
+template <typename T, typename MT, typename CacheT, int VecSize>
 __device__ void apply_k_rotary_emb_compute(
-    scalar_t* __restrict__ key, scalar_t* __restrict__ value,
-    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
-    const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr,
+    T* __restrict__ key, T* __restrict__ value,
+    CacheT* __restrict__ key_cache, CacheT* __restrict__ value_cache,
+    const MT* __restrict__ cos_ptr, const MT* __restrict__ sin_ptr,
     const int* __restrict__ sequence_lengths, const int* __restrict__ block_tables,
     const int64_t key_stride, const int64_t value_stride, const int token_id,
     const int block_table_stride, const int head_num, const int head_dim,
     const int kv_head_num, const int block_size, const int x,
     const int half_head_dim, const int shard_block_size) {
+
+  BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
+  BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
+  BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
   const int seq_len = sequence_lengths[token_id] - 1;
   const int block_offset = seq_len % block_size;
   const int block_id =
@@ -102,10 +115,10 @@ __device__ void apply_k_rotary_emb_compute(
     return;
   }
 
-  scalar_t x0[VecSize];
-  scalar_t x1[VecSize];
-  scalar_t out_x[VecSize];
-  scalar_t out_y[VecSize];
+  T x0[VecSize];
+  T x1[VecSize];
+  T out_x[VecSize];
+  T out_y[VecSize];
 
   for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
        i += blockDim.x * VecSize) {
@@ -123,37 +136,36 @@ __device__ void apply_k_rotary_emb_compute(
                                + block_offset * x + x_offset;
 
-    copy_vector<scalar_t, VecSize>(x0, key + addr_offset);
-    copy_vector<scalar_t, VecSize>(x1, key + addr_offset + half_head_dim);
+    copy<T, VecSize>(key + addr_offset, x0);
+    copy<T, VecSize>(key + addr_offset + half_head_dim, x1);
 
 #pragma unroll
     for (int j = 0; j < VecSize; j++) {
-      out_x[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x0[j]) * cos_ptr[j * 32 + shard_offset] -
-                                       static_cast<m_scalar_t>(x1[j]) * sin_ptr[j * 32 + shard_offset]);
-      out_y[j] = static_cast<scalar_t>(static_cast<m_scalar_t>(x1[j]) * cos_ptr[j * 32 + shard_offset] +
-                                       static_cast<m_scalar_t>(x0[j]) * sin_ptr[j * 32 + shard_offset]);
+      out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x0[j]), cos_ptr[j * 32 + shard_offset]),
+                                          mul(CastFunctor<T, MT>()(x1[j]), sin_ptr[j * 32 + shard_offset])));
+      out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(x1[j]), cos_ptr[j * 32 + shard_offset]),
+                                          mul(CastFunctor<T, MT>()(x0[j]), sin_ptr[j * 32 + shard_offset])));
     }
 
-    copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
-    copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim * block_size,
-                                   out_y);
+    copy<T, CacheT, VecSize>(out_x, key_cache + target_id);
+    copy<T, CacheT, VecSize>(out_y, key_cache + target_id + half_head_dim * block_size);
   }
 
   // apply value memcopy
-  apply_kv_memcopy<scalar_t, VecSize>(
+  apply_kv_memcopy<T, CacheT, VecSize>(
       value, value_cache, value_stride, token_id, block_id,
       kv_head_num * head_dim, block_size, block_offset, head_dim,
       half_head_dim);
 }
 
-template <typename scalar_t, typename m_scalar_t, int VecSize>
+template <typename T, typename MT, typename CacheT, int VecSize>
 __global__ void rotary_embedding_and_cache_copy_kernel(
-    scalar_t* __restrict__ query,
-    scalar_t* __restrict__ key,
-    scalar_t* __restrict__ value,
-    const scalar_t* __restrict__ cos,
-    const scalar_t* __restrict__ sin,
-    scalar_t* __restrict__ key_cache,
-    scalar_t* __restrict__ value_cache,
+    T* __restrict__ query,
+    T* __restrict__ key,
+    T* __restrict__ value,
+    const T* __restrict__ cos,
+    const T* __restrict__ sin,
+    CacheT* __restrict__ key_cache,
+    CacheT* __restrict__ value_cache,
     const int* __restrict__ sequence_lengths,
     const int* __restrict__ block_tables,
     const int64_t query_stride,
@@ -176,26 +188,26 @@ __global__ void rotary_embedding_and_cache_copy_kernel(
 
     extern __shared__ char shard_ptr[];
 
-    m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
-    m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
+    MT *cos_ptr = reinterpret_cast<MT*>(shard_ptr);
+    MT *sin_ptr = cos_ptr + half_shard_element_num;
 
     // apply cos_sin memcopy
-    cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
+    cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
     __syncthreads();
 
     //compute query
-    apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
+    apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
 
     //compute key and copy kv
-    apply_k_rotary_emb_compute<scalar_t, m_scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
+    apply_k_rotary_emb_compute<T, MT, CacheT, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, x, half_head_dim, shard_block_size);
 }
 
-template <typename scalar_t, typename m_scalar_t, int VecSize>
+template <typename T, typename MT, int VecSize>
 __global__ void rotary_embedding_kernel(
-    scalar_t* __restrict__ query,
-    scalar_t* __restrict__ key,
-    const scalar_t* __restrict__ cos,
-    const scalar_t* __restrict__ sin,
+    T* __restrict__ query,
+    T* __restrict__ key,
+    const T* __restrict__ cos,
+    const T* __restrict__ sin,
     const int64_t query_stride,
     const int64_t key_stride,
     const int64_t half_shard_element_num,
@@ -211,29 +223,29 @@ __global__ void rotary_embedding_kernel(
 
     extern __shared__ char shard_ptr[];
 
-    m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr;
-    m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
+    MT *cos_ptr = (MT*)shard_ptr;
+    MT *sin_ptr = cos_ptr + half_shard_element_num;
 
     // apply cos_sin memcopy
-    cos_sin_memory_access<scalar_t, m_scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
+    cos_sin_memory_access<T, MT, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
     __syncthreads();
 
     //compute query
-    apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
+    apply_emb_rotary_compute<T, MT, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
 
     //compute key
-    apply_emb_rotary_compute<scalar_t, m_scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
+    apply_emb_rotary_compute<T, MT, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
 }
 
 #define ROTARY_EMBEDDING_AND_CACHE_COPY_LAUNCHER(VEC_SIZE)                                             \
-  rotary_embedding_and_cache_copy_kernel<scalar_t, m_scalar_t, VEC_SIZE><<<grid, block, shared_memory_size>>>(  \
-    query.data_ptr<scalar_t>(),                                                                        \
-    key.data_ptr<scalar_t>(),                                                                          \
-    value.data_ptr<scalar_t>(),                                                                        \
-    cos.data_ptr<scalar_t>(),                                                                          \
-    sin.data_ptr<scalar_t>(),                                                                          \
-    key_cache.data_ptr<scalar_t>(),                                                                    \
-    value_cache.data_ptr<scalar_t>(),                                                                  \
+  rotary_embedding_and_cache_copy_kernel<T, MT, CacheT, VEC_SIZE><<<grid, block, shared_memory_size>>>(  \
+    reinterpret_cast<T*>(query.data_ptr()),                                                            \
+    reinterpret_cast<T*>(key.data_ptr()),                                                              \
+    reinterpret_cast<T*>(value.data_ptr()),                                                            \
+    reinterpret_cast<T*>(cos.data_ptr()),                                                              \
+    reinterpret_cast<T*>(sin.data_ptr()),                                                              \
+    reinterpret_cast<CacheT*>(key_cache.data_ptr()),                                                   \
+    reinterpret_cast<CacheT*>(value_cache.data_ptr()),                                                 \
     sequence_lengths.data_ptr<int>(),                                                                  \
     block_tables.data_ptr<int>(),                                                                      \
     query_stride,                                                                                      \
@@ -250,7 +262,7 @@ __global__ void rotary_embedding_kernel(
     x);                                                                                                \
 
-template <typename scalar_t, bool high_precision>
+template <typename T, typename CacheT, bool high_precision>
 void apply_rotary_embedding_and_cache_copy(
     at::Tensor& query,               // [num_tokens, head_num, head_dim]
     at::Tensor& key,                 // [num_tokens, kv_head_num, head_dim]
@@ -276,9 +288,9 @@ void apply_rotary_embedding_and_cache_copy(
     int sin_stride = sin.stride(0);
     int block_table_stride = block_tables.stride(0);
 
-    using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
+    using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
 
-    int vec_size = get_vec_size<scalar_t>(query);
+    int vec_size = get_vec_size<T>(query);
 
     if ((head_dim / 2) % vec_size != 0) {
         // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
@@ -293,7 +305,7 @@ void apply_rotary_embedding_and_cache_copy(
     dim3 grid(num_tokens);
     dim3 block(std::min(thread_nums, 512));
     int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size;
-    const int shared_memory_size = shard_element_num * sizeof(m_scalar_t);
+    const int shared_memory_size = shard_element_num * sizeof(MT);
 
     switch (vec_size) {
         case 1:
@@ -313,7 +325,7 @@ void apply_rotary_embedding_and_cache_copy(
     AT_CUDA_CHECK(cudaGetLastError());
 }
 
-template <typename scalar_t, bool high_precision>
+template <typename T, bool high_precision>
 void apply_rotary_embedding(
     at::Tensor& query,   // [total_tokens, head_num, head_dim]
     at::Tensor& key,     // [total_tokens, kv_head_num, head_dim]
@@ -330,9 +342,9 @@ void apply_rotary_embedding(
     int cos_stride = cos.stride(0);
     int sin_stride = sin.stride(0);
 
-    using m_scalar_t = typename colossalAI::common::ScalarTypeTrait<high_precision, scalar_t>::Type;
+    using MT = typename colossalAI::common::ScalarTypeTrait<high_precision, T>::Type;
 
-    int vec_size = get_vec_size<scalar_t>(query);
+    int vec_size = get_vec_size<T>(query);
 
     if ((head_dim / 2) % vec_size != 0) {
         // Disable vectorized loading optimization when head_dim is not divisible by VecSize.
@@ -350,11 +362,11 @@ void apply_rotary_embedding(
 
     switch (vec_size) {
         case 1:
-            rotary_embedding_kernel<scalar_t, m_scalar_t, 1><<<grid, block, shard_element_num * sizeof(m_scalar_t)>>>(
-                query.data_ptr<scalar_t>(),
-                key.data_ptr<scalar_t>(),
-                cos.data_ptr<scalar_t>(),
-                sin.data_ptr<scalar_t>(),
+            rotary_embedding_kernel<T, MT, 1><<<grid, block, shard_element_num * sizeof(MT)>>>(
+                query.data_ptr<T>(),
+                key.data_ptr<T>(),
+                cos.data_ptr<T>(),
+                sin.data_ptr<T>(),
                 query_stride,
                 key_stride,
                 shard_element_num / 2,
@@ -366,11 +378,11 @@ void apply_rotary_embedding(
             );
             break;
         case 2:
-            rotary_embedding_kernel<scalar_t, m_scalar_t, 2><<<grid, block, shard_element_num * sizeof(m_scalar_t)>>>(
-                query.data_ptr<scalar_t>(),
-                key.data_ptr<scalar_t>(),
-                cos.data_ptr<scalar_t>(),
-                sin.data_ptr<scalar_t>(),
+            rotary_embedding_kernel<T, MT, 2><<<grid, block, shard_element_num * sizeof(MT)>>>(
+                query.data_ptr<T>(),
+                key.data_ptr<T>(),
+                cos.data_ptr<T>(),
+                sin.data_ptr<T>(),
                 query_stride,
                 key_stride,
                 shard_element_num / 2,
@@ -382,11 +394,11 @@ void apply_rotary_embedding(
             );
             break;
        case 4:
-            rotary_embedding_kernel<scalar_t, m_scalar_t, 4><<<grid, block, shard_element_num * sizeof(m_scalar_t)>>>(
-                query.data_ptr<scalar_t>(),
-                key.data_ptr<scalar_t>(),
-                cos.data_ptr<scalar_t>(),
-                sin.data_ptr<scalar_t>(),
+            rotary_embedding_kernel<T, MT, 4><<<grid, block, shard_element_num * sizeof(MT)>>>(
+                query.data_ptr<T>(),
+                key.data_ptr<T>(),
+                cos.data_ptr<T>(),
+                sin.data_ptr<T>(),
                 query_stride,
                 key_stride,
                 shard_element_num / 2,
@@ -416,21 +428,81 @@ void rotary_embedding_and_cache_copy(
     at::Tensor& block_tables,       // [batch_size, max_seq_len]
     bool high_precision)
 {
-    DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
-        high_precision,
-        query.scalar_type(),
-        "rotary_embedding_and_cache_copy",
-        apply_rotary_embedding_and_cache_copy(
-            query,
-            key,
-            value,
-            cos,
-            sin,
-            key_cache,
-            value_cache,
-            sequence_lengths,
-            block_tables
-        );)
+#define _(T, CacheT, HIGH_PRECISION)                                   \
+    apply_rotary_embedding_and_cache_copy<T, CacheT, HIGH_PRECISION>(  \
+        query,                                                         \
+        key,                                                           \
+        value,                                                         \
+        cos,                                                           \
+        sin,                                                           \
+        key_cache,                                                     \
+        value_cache,                                                   \
+        sequence_lengths,                                              \
+        block_tables);
+
+    if(key_cache.scalar_type() == at::ScalarType::Byte)
+    {
+        if(high_precision) {
+            switch (key.scalar_type())
+            {
+                case at::ScalarType::Float:
+                    _(float, uint8_t, true)
+                    break;
+                case at::ScalarType::Half:
+                    _(half, uint8_t, true)
+                    break;
+                case at::ScalarType::BFloat16:
+                    _(__nv_bfloat16, uint8_t, true)
+                    break;
+            }
+        }
+        else {
+            switch (key.scalar_type())
+            {
+                case at::ScalarType::Float:
+                    _(float, uint8_t, false)
+                    break;
+                case at::ScalarType::Half:
+                    _(half, uint8_t, false)
+                    break;
+                case at::ScalarType::BFloat16:
+                    _(__nv_bfloat16, uint8_t, false)
+                    break;
+            }
+        }
+    }
+    else
+    {
+        if(high_precision) {
+            switch (key.scalar_type())
+            {
+                case at::ScalarType::Float:
+                    _(float, float, true)
+                    break;
+                case at::ScalarType::Half:
+                    _(half, half, true)
+                    break;
+                case at::ScalarType::BFloat16:
+                    _(__nv_bfloat16, __nv_bfloat16, true)
+                    break;
+            }
+        }
+        else {
+            switch (key.scalar_type())
+            {
+                case at::ScalarType::Float:
+                    _(float, float, false)
+                    break;
+                case at::ScalarType::Half:
+                    _(half, half, false)
+                    break;
+                case at::ScalarType::BFloat16:
+                    _(__nv_bfloat16, __nv_bfloat16, false)
+                    break;
+            }
+        }
+    }
+#undef _
 }
 
 void rotary_embedding(
diff --git a/extensions/csrc/kernel/cuda/utils/vec_copy.h b/extensions/csrc/kernel/cuda/utils/vec_copy.h
index ad98361dd..7cc071c66 100644
--- a/extensions/csrc/kernel/cuda/utils/vec_copy.h
+++ b/extensions/csrc/kernel/cuda/utils/vec_copy.h
@@ -11,6 +11,7 @@
 namespace colossalAI {
 namespace cuda {
 namespace utils {
 
+// Note(LiuYang): Deprecated
 template <typename T, int VecSize>
 __device__ __inline__ void copy_vector(T *dst, const T *src) {
   using VT = typename common::VecTypeTrait<T, VecSize>::Type;
@@ -26,6 +27,7 @@ __device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
       *(reinterpret_cast<const float4 *>(src + 4));
 }
 
+// Note(LiuYang): Deprecated
 template <typename T, int VecSize>
 __device__ __inline__ void copy_zero_vector(T *dst) {
   using VT = typename common::VecTypeTrait<T, VecSize>::Type;
@@ -36,13 +38,12 @@ template <typename SrcT, typename DstT, int VecSize>
 __device__ __inline__ void copy(const SrcT *src, DstT *dst) {
   using SrcVT = typename common::VecTypeTrait<SrcT, VecSize>::Type;
   using DstVT = typename common::VecTypeTrait<DstT, VecSize>::Type;
-  // Note(LiuYang): Here static_cast can't be used for cast between two pointer
   *(reinterpret_cast<DstVT *>(dst)) = funcs::CastFunctor<SrcVT, DstVT>()(
       *(reinterpret_cast<const SrcVT *>(src)));
 }
 
 template <typename T, int VecSize>
-__device__ __inline__ void copy(const T *src, T *dst) {
+__device__ __inline__ void copy(const T *src, T *dst) {
   using VT = typename common::VecTypeTrait<T, VecSize>::Type;
   *(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
 }
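
Illustrative sketch (not part of the patch): the new copy<SrcT, DstT, VecSize> helper plus the CastFunctor specializations let a kernel load a small vector of key/value elements, convert each element, and store them into a cache whose element type differs from the source (uint8_t-backed fp8 in this PR). The standalone CUDA sketch below shows that load-cast-store pattern with a plain float -> half conversion and hypothetical names (cast_store, demo_cast_copy); it uses an element-wise static_cast for clarity, whereas the library version reinterprets whole vectors and converts them through CastFunctor<SrcVT, DstVT>.

// Minimal sketch only; names and the float -> half conversion are assumptions
// for illustration, not code from this repository.
#include <cuda_fp16.h>

// Cast VecSize contiguous elements from SrcT to DstT.
template <typename SrcT, typename DstT, int VecSize>
__device__ __inline__ void cast_store(const SrcT* __restrict__ src,
                                      DstT* __restrict__ dst) {
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // A production kernel would route this through a CastFunctor<SrcT, DstT>
    // specialization (e.g. an fp8 encode); static_cast suffices for the demo.
    dst[i] = static_cast<DstT>(src[i]);
  }
}

// Each thread converts VecSize consecutive source elements into the cache.
template <int VecSize>
__global__ void demo_cast_copy(const float* __restrict__ src,
                               half* __restrict__ cache, int n) {
  const int base = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  if (base + VecSize <= n) {
    cast_store<float, half, VecSize>(src + base, cache + base);
  }
}

// Example launch: 128 threads x 4 elements = 512 elements per block.
// demo_cast_copy<4><<<(n + 511) / 512, 128>>>(d_src, d_cache, n);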