From 48c4f29b275e2d8105842913cd84f5d66c378b36 Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Tue, 19 Mar 2024 11:32:01 +0800 Subject: [PATCH] refactor vector utils --- .../cuda/decode_kv_cache_memcpy_kernel.cu | 2 +- .../cuda/fused_rotary_emb_and_cache_kernel.cu | 2 +- extensions/csrc/cuda/rms_layernorm_kernel.cu | 2 +- extensions/csrc/cuda/scaled_masked_softmax.h | 42 +----------- .../cuda/scaled_upper_triang_masked_softmax.h | 64 ------------------- .../{common => cuda/utils}/cuda_type_utils.h | 0 extensions/csrc/cuda/utils/vec_type_traits.h | 12 ++++ .../utils}/vector_copy_utils.h | 42 +++++++++++- 8 files changed, 57 insertions(+), 109 deletions(-) rename extensions/csrc/{common => cuda/utils}/cuda_type_utils.h (100%) create mode 100644 extensions/csrc/cuda/utils/vec_type_traits.h rename extensions/csrc/{common => cuda/utils}/vector_copy_utils.h (72%) diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 7eb44ecd0..3b1197a91 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -1,7 +1,7 @@ #include #include -#include "../common/vector_copy_utils.h" +#include "utils/vector_copy_utils.h" #include "../common/micros.h" template diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu index c1db06d3f..697dc7110 100644 --- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -2,7 +2,7 @@ #include #include -#include "../common/vector_copy_utils.h" +#include "utils/vector_copy_utils.h" #include "../common/micros.h" template diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 8b250cb10..50f26510e 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -10,7 +10,7 @@ #include "block_reduce.h" #include "../common/micros.h" -#include "../common/cuda_type_utils.h" +#include "utils/cuda_type_utils.h" #define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \ if (DATA_SIZE == 2) { \ diff --git a/extensions/csrc/cuda/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h index d3e6f04e6..cbbe7f36a 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax.h +++ b/extensions/csrc/cuda/scaled_masked_softmax.h @@ -6,52 +6,14 @@ #include #include #include -#include #include #include +#include "utils/vector_copy_utils.h" + namespace { -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h index 54c8e9133..524ef46c6 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h @@ -13,70 +13,6 @@ namespace { -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; diff --git a/extensions/csrc/common/cuda_type_utils.h b/extensions/csrc/cuda/utils/cuda_type_utils.h similarity index 100% rename from extensions/csrc/common/cuda_type_utils.h rename to extensions/csrc/cuda/utils/cuda_type_utils.h diff --git a/extensions/csrc/cuda/utils/vec_type_traits.h b/extensions/csrc/cuda/utils/vec_type_traits.h new file mode 100644 index 000000000..fddd1d5ac --- /dev/null +++ b/extensions/csrc/cuda/utils/vec_type_traits.h @@ -0,0 +1,12 @@ +#pragma once + +namespace colossalAI { +namespace cuda { +namespace utils { + +template +class VecTypeTraits {}; + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/common/vector_copy_utils.h b/extensions/csrc/cuda/utils/vector_copy_utils.h similarity index 72% rename from extensions/csrc/common/vector_copy_utils.h rename to extensions/csrc/cuda/utils/vector_copy_utils.h index 456440cf6..556036332 100644 --- a/extensions/csrc/common/vector_copy_utils.h +++ b/extensions/csrc/cuda/utils/vector_copy_utils.h @@ -1,11 +1,12 @@ +#pragma once + #include #include +#include #include -#include "string" - template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); @@ -57,6 +58,18 @@ __device__ __inline__ void copy_vector(c10::Half *dst, *((float4 *)dst) = *((float4 *)src); } +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} + template <> __device__ __inline__ void copy_vector(float *dst, const float *src) { *dst = *src; @@ -80,6 +93,31 @@ __device__ __inline__ void copy_vector(float *dst, const float *src) { *((float4 *)(dst + 4)) = *((float4 *)(src + 4)); } +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *dst = 0.0; +} + +template <> +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *dst = 0.0; +} + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} + template int get_vec_size(const torch::Tensor &tensor) { uint64_t address = reinterpret_cast(tensor.data_ptr());