mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
refactor vector utils
This commit is contained in:
@@ -1,122 +0,0 @@
|
||||
/*
 * This code is adapted from NVIDIA FasterTransformer:
 * https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
 */
#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#if ENABLE_BF16
#include <cuda_bf16.h>  // __nv_bfloat16 / __nv_bfloat162 used below
#endif  // ENABLE_BF16
// Generic elementwise add. Specialized below so half / half2 (and, when
// ENABLE_BF16 is set, bfloat16 types) use the native packed intrinsics
// instead of operator+.
template <typename T>
inline __device__ T add(T a, T b) {
  const T sum = a + b;
  return sum;
}
|
||||
|
||||
// Specialization for half2: adds both packed half lanes with a single
// __hadd2 instruction.
template <>
inline __device__ half2 add(half2 a, half2 b) {
  return __hadd2(a, b);
}
|
||||
|
||||
// Specialization for half: scalar half-precision add intrinsic.
template <>
inline __device__ half add(half a, half b) {
  return __hadd(a, b);
}
|
||||
|
||||
#if ENABLE_BF16
// bfloat16 specializations.
// NOTE(review): bf16hadd2 / bf16hadd are not CUDA intrinsics; they are
// FasterTransformer-style helpers assumed to be defined elsewhere in this
// project (typically wrapping __hadd2/__hadd, with a float round-trip
// fallback on pre-SM80 devices) -- confirm they are in scope wherever this
// header is included.
template <>
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
  return bf16hadd2(a, b);
}

template <>
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
  return bf16hadd(a, b);
}
#endif  // ENABLE_BF16
|
||||
|
||||
// Generic three-operand product, evaluated left to right: (a * b) * c.
// Specialized below for packed half / bfloat16 types.
template <typename T>
inline __device__ T mul(T a, T b, T c) {
  const T ab = a * b;
  return ab * c;
}
|
||||
|
||||
// Specialization for half2: two packed half multiplies, (a * b) * c.
template <>
inline __device__ half2 mul(half2 a, half2 b, half2 c) {
  return __hmul2(__hmul2(a, b), c);
}
|
||||
|
||||
#if ENABLE_BF16
// bfloat16 three-operand product specializations.
// NOTE(review): bf16hmul / bf16hmul2 are FasterTransformer-style helpers
// assumed to be defined elsewhere in this project -- confirm they are in
// scope wherever this header is included.
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
                                    __nv_bfloat16 c) {
  return bf16hmul(a, b, c);
}

// Was a plain overload; made an explicit specialization for consistency with
// the half2 / __nv_bfloat16 variants, so an explicit mul<__nv_bfloat162>(...)
// call also reaches the intrinsic path instead of the generic a * b * c.
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
                                     __nv_bfloat162 c) {
  return bf16hmul2(a, b, c);
}
#endif  // ENABLE_BF16
|
||||
|
||||
// Generic fallback cast: relies on the implicit conversion from T_IN to
// T_OUT. Specialized below for the scalar/vector conversions that need
// dedicated CUDA intrinsics.
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  T_OUT converted = val;
  return converted;
}
|
||||
|
||||
// int2 -> float2: per-component int-to-float conversion.
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
  return make_float2(val.x, val.y);
}

// float -> float2: broadcast the scalar into both components.
template <>
__device__ inline float2 cuda_cast<float2, float>(float val) {
  return make_float2(val, val);
}

// half2 -> float2: widen both packed halves to float.
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
  return __half22float2(val);
}

// float2 -> half2: narrow both components (round-to-nearest-even).
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
  return __float22half2_rn(val);
}

// float -> half2: convert once (round-to-nearest-even) and broadcast into
// both lanes.
template <>
__device__ inline half2 cuda_cast<half2, float>(float val) {
  return __float2half2_rn(val);
}

// half -> half2: broadcast the scalar half into both lanes.
template <>
__device__ inline half2 cuda_cast<half2, half>(half val) {
  return __half2half2(val);
}

// half -> float: scalar widening conversion.
template <>
__device__ inline float cuda_cast<float, half>(half val) {
  return __half2float(val);
}
|
||||
|
||||
// Get type2 from type or vice versa (applied to half and bfloat16)
// Maps between PyTorch scalar types (at::Half / at::BFloat16) and CUDA
// packed vector types (half2 / __nv_bfloat162), in both directions.
// NOTE(review): the unspecialized template yields half2 for ANY other T,
// which would silently map e.g. float to half2 -- confirm no caller
// instantiates TypeConverter with a type outside the specializations below.
// at::Half / at::BFloat16 require the ATen headers to already be included by
// the consumer of this header -- verify.
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = at::Half;
};

template <>
struct TypeConverter<at::Half> {
  using Type = half2;
};

#if ENABLE_BF16
template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = at::BFloat16;
};

template <>
struct TypeConverter<at::BFloat16> {
  using Type = __nv_bfloat162;
};
#endif  // ENABLE_BF16
|
@@ -1,98 +0,0 @@
|
||||
|
||||
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>

#include <algorithm>  // std::min (used by get_vec_size)
#include <cfloat>
#include <cstdint>    // uint64_t
#include <string>
// Copies ELEMENTS_PER_LDG consecutive elements from src to dst using the
// widest single load/store available (specializations below). The primary
// template is declared but never defined, so an unsupported
// (type, element-count) pair fails at link time instead of silently copying
// element by element.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||
|
||||
// c10::BFloat16 is 2 bytes, so 2/4/8 elements are moved as one 4/8/16-byte
// vector load + store.
// NOTE(review): the reinterpreting casts assume src and dst are aligned to
// the vector width (4/8/16 bytes) -- callers must guarantee this (see
// get_vec_size).
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  // 2 x 2 bytes == one 4-byte (float-sized) move.
  *((float *)dst) = *((float *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  // 4 x 2 bytes == one 8-byte (float2-sized) move.
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  // 8 x 2 bytes == one 16-byte (float4-sized) move.
  *((float4 *)dst) = *((float4 *)src);
}
|
||||
|
||||
// c10::Half is 2 bytes, so 2/4/8 elements are moved as one 4/8/16-byte
// vector load + store.
// NOTE(review): the reinterpreting casts assume src and dst are aligned to
// the vector width (4/8/16 bytes) -- callers must guarantee this (see
// get_vec_size).
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
                                                     const c10::Half *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
                                                     const c10::Half *src) {
  // 2 x 2 bytes == one 4-byte (float-sized) move.
  *((float *)dst) = *((float *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
                                                     const c10::Half *src) {
  // 4 x 2 bytes == one 8-byte (float2-sized) move.
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
                                                     const c10::Half *src) {
  // 8 x 2 bytes == one 16-byte (float4-sized) move.
  *((float4 *)dst) = *((float4 *)src);
}
|
||||
|
||||
// float copies: 2/4 elements as one float2/float4 move; 8 elements as two
// float4 moves, since 128 bits is the widest single memory access.
// NOTE(review): assumes src and dst are aligned to the vector width --
// callers must guarantee this (see get_vec_size).
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
  *dst = *src;
}

template <>
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
  *((float2 *)dst) = *((float2 *)src);
}

template <>
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
  *((float4 *)dst) = *((float4 *)src);
}

template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
  // Since the maximum memory alignment length is 128 bits, we choose float4
  // here.
  *((float4 *)dst) = *((float4 *)src);
  *((float4 *)(dst + 4)) = *((float4 *)(src + 4));
}
|
||||
|
||||
// Returns the widest vectorized access width, in elements (4, 2, or 1), that
// can be used on `tensor`'s data. The width is limited by (a) the 128-bit
// maximum memory transaction size and (b) the byte alignment of the data
// pointer: an N-element access requires N * sizeof(T)-byte alignment.
//
// Fix: the original compared the byte address against dtype_size * 4 where
// dtype_size was in BITS (sizeof(T) * 8), i.e. it demanded 8x stricter
// alignment than the access needs and so fell back to narrower widths for
// perfectly usable pointers. Addresses are in bytes, so the checks below use
// N * sizeof(T) bytes.
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
  const uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
  // Widest access is 128 bits; vec_size is how many T fit into it.
  constexpr int max_aligned_bits = 128;
  constexpr int vec_size = max_aligned_bits / (sizeof(T) * 8);

  if (address % (sizeof(T) * 4) == 0) {
    return std::min(4, vec_size);
  } else if (address % (sizeof(T) * 2) == 0) {
    return std::min(2, vec_size);
  }
  return 1;
}
|
Reference in New Issue
Block a user