diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h index 8ede2d448..2a767620a 100644 --- a/extensions/csrc/common/mp_type_traits.h +++ b/extensions/csrc/common/mp_type_traits.h @@ -8,26 +8,22 @@ namespace colossalAI { namespace common { template -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; template <> -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; template <> -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; template <> -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index a65a3df8e..372b30387 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -4,7 +4,6 @@ #include "../common/micros.h" #include "../common/mp_type_traits.h" -#include "utils/gpu_launch_config.h" template __device__ __forceinline__ T silu_kernel(const T& x) { diff --git a/extensions/csrc/cuda/utils/vec_type_traits.h b/extensions/csrc/cuda/utils/vec_type_traits.h index fddd1d5ac..3ddd64df9 100644 --- a/extensions/csrc/cuda/utils/vec_type_traits.h +++ b/extensions/csrc/cuda/utils/vec_type_traits.h @@ -1,11 +1,82 @@ #pragma once +#include +#include +#include + +#include + namespace colossalAI { namespace cuda { namespace utils { -template -class VecTypeTraits {}; +template +struct VecTypeTrait {}; + +template +struct VecTypeTrait { + using Type = T; +}; + +template <> +struct VecTypeTrait { + using Type = float; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = half; +}; + +template <> +struct VecTypeTrait { + using Type = half2; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; } // namespace utils } // namespace cuda diff --git a/extensions/csrc/cuda/utils/vector_copy_utils.h b/extensions/csrc/cuda/utils/vector_copy_utils.h index 556036332..3c3afa0b3 100644 --- a/extensions/csrc/cuda/utils/vector_copy_utils.h +++ b/extensions/csrc/cuda/utils/vector_copy_utils.h @@ -5,117 +5,28 @@ #include #include -#include +#include "vec_type_traits.h" -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float *)dst) = *((float *)src); -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float4 *)dst) = *((float4 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float *)dst) = *((float *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float4 *)dst) = *((float4 *)src); -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *((float4 *)dst) = *((float4 *)src); +template +__device__ __inline__ void copy_vector(T *dst, const T *src) { + using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; + // Note(LiuYang): Here static_cast can't be used for cast between two pointer + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); } template <> __device__ __inline__ void copy_vector(float *dst, const float *src) { // Since the maximum memory alignment length is 128 bits, we choose float4 // here. - *((float4 *)dst) = *((float4 *)src); - *((float4 *)(dst + 4)) = *((float4 *)(src + 4)); + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst + 4)) = + *(reinterpret_cast(src + 4)); } -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); +template +__device__ __inline__ void copy_zero_vector(T *dst) { + using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; + *(reinterpret_cast(dst)) = {0.0}; } template @@ -126,6 +37,11 @@ int get_vec_size(const torch::Tensor &tensor) { const int vec_size = max_aligned_size / sizeof(T) / 8; + // Note(LiuYang): Performance of situation of which + // vec_size equals to 8 need to be profiled in the future + // if (address % (dtype_size * 8) == 0) { + // return std::min(8, vec_size); + // } if (address % (dtype_size * 4) == 0) { return std::min(4, vec_size); } else if (address % (dtype_size * 2) == 0) {