mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[Inference/Refactor] Refactor compilation mechanism and unified multi hw (#5613)
* refactor compilation mechanism and unified multi hw * fix file path bug * add init.py to make pybind a module to avoid relative path error caused by softlink * delete duplicated micros * fix micros bug in gcc
This commit is contained in:
210
extensions/csrc/funcs/binary_functor.h
Normal file
210
extensions/csrc/funcs/binary_functor.h
Normal file
@@ -0,0 +1,210 @@
|
||||
#pragma once

#if defined(COLOSSAL_WITH_CUDA)
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif

#include <functional>

#include "cast_functor.h"
#include "common/data_type.h"
#include "common/micros.h"

namespace colossalAI {
namespace funcs {

enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

// Note(LiuYang): This file provides base math operation for data type
// include POD and cuda built-in type such as half and __nv_bfloat16.
// Implementation of common and simple binary operators should be placed here,
// otherwise, they should be placed in a new file under functors dir.
template <typename LT, typename RT, typename RET, BinaryOpType op_type>
struct BinaryOpFunctor;

#define STMTS_WRAPPER(...) __VA_ARGS__

// Fix(review): the specializations used to inherit from std::binary_function,
// which was deprecated in C++11 and removed in C++17, so the header failed to
// compile under -std=c++17 and later. None of the typedefs the base provided
// (first_argument_type, ...) are used anywhere, so the base class is dropped.
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(                     \
    LT, RT, RET, BINARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
  template <ARGS>                                                   \
  struct BinaryOpFunctor<LT, RT, RET, BINARY_OP_TYPE> {             \
    FUNCTION_MODIFIER RET operator()(LT lhs, RT rhs) STMTS          \
  };

// Generic fallbacks for any type that supports the built-in operator.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kAdd, HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs + rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMinus,
                                       HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs - rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMul, HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs * rhs; }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kDiv, HOSTDEVICE,
                                       STMTS_WRAPPER({ return lhs / rhs; }),
                                       typename T)
// NOTE(review): max/min are left unqualified on purpose so device code can
// pick up the CUDA overloads; host code resolves to whatever is in scope.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMax, HOSTDEVICE,
                                       STMTS_WRAPPER({ return max(lhs, rhs); }),
                                       typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, T, T, BinaryOpType::kMin, HOSTDEVICE,
                                       STMTS_WRAPPER({ return min(lhs, rhs); }),
                                       typename T)

#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd2(lhs, rhs);
                                       }))

// bf16 arithmetic intrinsics (__hadd/__hadd2 on __nv_bfloat16) require
// sm_80+; older architectures fall back to a round trip through float.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                       __nv_bfloat16, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                       __nv_bfloat162, BinaryOpType::kAdd,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hadd2(lhs, rhs);
                                       }))
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kAdd, DEVICE,
    STMTS_WRAPPER({
      return __float2bfloat16(__bfloat162float(lhs) + __bfloat162float(rhs));
    }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kAdd, DEVICE,
    STMTS_WRAPPER({
      return __floats2bfloat162_rn(__low2float(lhs) + __low2float(rhs),
                                   __high2float(lhs) + __high2float(rhs));
    }))
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, half, half, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, half2, half2, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul2(lhs, rhs);
                                       }))

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat16,
                                       __nv_bfloat16, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul(lhs, rhs);
                                       }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, __nv_bfloat162,
                                       __nv_bfloat162, BinaryOpType::kMul,
                                       DEVICE, STMTS_WRAPPER({
                                         return __hmul2(lhs, rhs);
                                       }))
#else
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      return __float2bfloat16(__bfloat162float(lhs) * __bfloat162float(rhs));
    }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, __nv_bfloat162, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      return __floats2bfloat162_rn(__low2float(lhs) * __low2float(rhs),
                                   __high2float(lhs) * __high2float(rhs));
    }))
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    float2, float2, float2, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({ return make_float2(lhs.x * rhs.x, lhs.y * rhs.y); }))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(float4, float4, float4,
                                       BinaryOpType::kMul, DEVICE,
                                       STMTS_WRAPPER({
                                         return make_float4(
                                             lhs.x * rhs.x, lhs.y * rhs.y,
                                             lhs.z * rhs.z, lhs.w * rhs.w);
                                       }))

// Widening multiplies: bf16/half packed inputs, float outputs. Each widens
// through CastFunctor and reuses the float2 kMul specialization above.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, float2, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      CastFunctor<__nv_bfloat162, float2> cast;
      BinaryOpFunctor<float2, float2, float2, BinaryOpType::kMul> mul;
      float2 fa = cast(lhs);
      float2 fb = cast(rhs);
      return mul(fa, fb);
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul,
    DEVICE, STMTS_WRAPPER({
      dtype::float4_ fc;
      BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                      BinaryOpType::kMul>
          mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      return fc;
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul,
    DEVICE, STMTS_WRAPPER({
      dtype::float8_ fc;
      BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                      BinaryOpType::kMul>
          mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      fc.z = mul(lhs.z, rhs.z);
      fc.w = mul(lhs.w, rhs.w);
      return fc;
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    half2, half2, float2, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({
      CastFunctor<half2, float2> cast;
      BinaryOpFunctor<float2, float2, float2, BinaryOpType::kMul> mul;
      float2 fa = cast(lhs);
      float2 fb = cast(rhs);
      return mul(fa, fb);
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      dtype::float4_ fc;
      BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      return fc;
    }))

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
    dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE,
    STMTS_WRAPPER({
      dtype::float8_ fc;
      BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
      fc.x = mul(lhs.x, rhs.x);
      fc.y = mul(lhs.y, rhs.y);
      fc.z = mul(lhs.z, rhs.z);
      fc.w = mul(lhs.w, rhs.w);
      return fc;
    }))

#endif /* defined(COLOSSAL_WITH_CUDA) */

#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION
#undef STMTS_WRAPPER
}  // namespace funcs
}  // namespace colossalAI
|
163
extensions/csrc/funcs/cast_functor.h
Normal file
163
extensions/csrc/funcs/cast_functor.h
Normal file
@@ -0,0 +1,163 @@
|
||||
#pragma once

#if defined(COLOSSAL_WITH_CUDA)
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif

#include <functional>

#include "common/data_type.h"
#include "common/micros.h"

// Note(LiuYang): This file provides base math operation for data type
// include POD and cuda built-in type such as half and __nv_bfloat16

namespace colossalAI {
namespace funcs {

// Fix(review): CastFunctor (and the specialization macro below) used to
// inherit from std::unary_function, which was deprecated in C++11 and removed
// in C++17, breaking compilation under -std=c++17 and later. The typedefs it
// provided were never used, so the base class is dropped.

// Generic conversion: fall back to static_cast for any From -> To pair.
template <typename From, typename To>
struct CastFunctor {
  HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
};

#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMTS, \
                                             FUNCTION_MODIFIER) \
  template <>                                                   \
  struct CastFunctor<FROM, TO> {                                \
    FUNCTION_MODIFIER TO operator()(FROM val) STMTS             \
  };

#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    int2, float2, { return make_float2(val.x, val.y); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, float2, { return make_float2(val, val); }, DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    half2, float2, { return __half22float2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float2, half2, { return __float22half2_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, half, { return __float2half_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, half2, { return __float2half2_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    half, half2, { return __half2half2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    half, float, { return __half2float(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4, dtype::half4,
    {
      dtype::half4 dst;
      dst.x = __floats2half2_rn(val.x, val.y);
      dst.y = __floats2half2_rn(val.z, val.w);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    dtype::float4_, dtype::half4,
    {
      dtype::half4 dst;
      dst.x = __float22half2_rn(val.x);
      dst.y = __float22half2_rn(val.y);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    dtype::float8_, dtype::half8,
    {
      dtype::half8 dst;
      dst.x = __float22half2_rn(val.x);
      dst.y = __float22half2_rn(val.y);
      dst.z = __float22half2_rn(val.z);
      dst.w = __float22half2_rn(val.w);
      return dst;
    },
    DEVICE)

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, __nv_bfloat162, { return __float2bfloat162_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float, __nv_bfloat16, { return __float2bfloat16_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float4, dtype::bfloat164,
    {
      dtype::bfloat164 dst;
      dst.x = __floats2bfloat162_rn(val.x, val.y);
      dst.y = __floats2bfloat162_rn(val.z, val.w);
      return dst;
    },
    DEVICE)
// bf16 conversion intrinsics taking/returning __nv_bfloat162 require sm_80+;
// pre-Ampere falls back to per-lane float round trips below.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat162, { return __bfloat162bfloat162(val); },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, float2, { return __bfloat1622float2(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float2, __nv_bfloat162, { return __float22bfloat162_rn(val); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    dtype::float4_, dtype::bfloat164,
    {
      dtype::bfloat164 dst;
      dst.x = __float22bfloat162_rn(val.x);
      dst.y = __float22bfloat162_rn(val.y);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    dtype::float8_, dtype::bfloat168,
    {
      dtype::bfloat168 dst;
      dst.x = __float22bfloat162_rn(val.x);
      dst.y = __float22bfloat162_rn(val.y);
      dst.z = __float22bfloat162_rn(val.z);
      dst.w = __float22bfloat162_rn(val.w);
      return dst;
    },
    DEVICE)
#else
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat162,
    {
      __nv_bfloat162 dst;
      dst.x = val;
      dst.y = val;
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, float2,
    { return make_float2(__low2float(val), __high2float(val)); }, DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    float2, __nv_bfloat162, { return __floats2bfloat162_rn(val.x, val.y); },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    dtype::float4_, dtype::bfloat164,
    {
      dtype::bfloat164 dst;
      dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
      dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
      return dst;
    },
    DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
    dtype::float8_, dtype::bfloat168,
    {
      dtype::bfloat168 dst;
      dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
      dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
      dst.z = __floats2bfloat162_rn(val.z.x, val.z.y);
      dst.w = __floats2bfloat162_rn(val.w.x, val.w.y);
      return dst;
    },
    DEVICE)
#endif /* defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 */
#endif /* defined(COLOSSAL_WITH_CUDA) */

#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
}  // namespace funcs
}  // namespace colossalAI
|
94
extensions/csrc/funcs/reduce_function.h
Normal file
94
extensions/csrc/funcs/reduce_function.h
Normal file
@@ -0,0 +1,94 @@
|
||||
#pragma once

// This header is CUDA-only: the shuffle/shared-memory reductions below have
// no host fallback, so everything (including the namespace) sits behind the
// guard.
#if defined(COLOSSAL_WITH_CUDA)
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "binary_functor.h"

namespace colossalAI {
namespace funcs {

// Sentinels used as reduction identities (approximate +/- infinity; kept as
// large finite values so they cast cleanly to half/bf16).
const float kReduceFloatInfNeg = -100000000.f;
const float kReduceFloatInfPos = 100000000.f;
// Full-warp participation mask for __shfl_xor_sync.
const unsigned int kWarpReduceMask = 0xffffffff;

enum class ReduceType { kMax = 0, kSum };

// Maps a ReduceType onto the BinaryOpFunctor implementing it.
template <typename T, ReduceType rtype>
struct GetOpForReduceType;

template <typename T>
struct GetOpForReduceType<T, ReduceType::kMax> {
  using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
};

template <typename T>
struct GetOpForReduceType<T, ReduceType::kSum> {
  using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
};

// One butterfly-exchange step: combine each lane's LANES values with the
// partner lane's values DELTA apart.
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
  _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) {   \
    *(VAL_PTR + offset) =                                              \
        OP(*(VAL_PTR + offset),                                        \
           __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH));  \
  }

#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, WIDTH, OP, LANES)            \
  _Pragma("unroll") for (int DELTA = (WIDTH >> 1); DELTA > 0; DELTA >>= 1) { \
    COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES)           \
  }

// Block-level reduction: per-warp reduce, stage warp results in shared
// memory, then reduce the staged values in the first warp.
// Fix(review): the macro previously hard-coded `T` instead of using its
// DTYPE parameter (`__shared__ T shm...`, `static_cast<T>(...)`); it only
// compiled because the sole call site's template parameter happens to be
// named T. DTYPE is now used consistently.
#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, VAL_PTR, OP, LANES, DEFAULT_VALUE, \
                                   REDUCE_TYPE)                              \
  __shared__ DTYPE shm[LANES][32];                                           \
  int lane_id = threadIdx.x & 0x1f;                                          \
  int warp_id = threadIdx.x >> 5;                                            \
                                                                             \
  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);                           \
  if (lane_id == 0) {                                                        \
    for (int offset = 0; offset < LANES; ++offset) {                         \
      shm[offset][warp_id] = *(VAL_PTR + offset);                            \
    }                                                                        \
  }                                                                          \
  __syncthreads();                                                           \
                                                                             \
  _Pragma("unroll") for (int offset = 0; offset < LANES; ++offset) {         \
    *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5))                  \
                              ? shm[offset][lane_id]                         \
                              : static_cast<DTYPE>(DEFAULT_VALUE);           \
  }                                                                          \
  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);

// In-place warp reduction over `lanes` values starting at pval; after the
// call every lane of the warp holds the reduced values.
template <typename T, ReduceType rtype, int lanes, int width = 32>
__forceinline__ __device__ void warp_reduce(T* pval) {
  typename GetOpForReduceType<T, rtype>::Op op;
  COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, width, op, lanes);
}

// Identity element fed to threads that have no warp result to contribute.
template <typename T, ReduceType rtype>
__forceinline__ __device__ constexpr T GetDefaultValueForBlockReduce() {
  if constexpr (rtype == ReduceType::kSum) {
    return static_cast<T>(0.0f);
  } else if constexpr (rtype == ReduceType::kMax) {
    return static_cast<T>(kReduceFloatInfNeg);
  }
}

// In-place block reduction over `lanes` values starting at pval.
// NOTE(review): assumes blockDim.x is a multiple of 32 — confirm at callers.
template <typename T, ReduceType rtype, int lanes>
__forceinline__ __device__ void block_reduce(T* pval) {
  constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
  typename GetOpForReduceType<T, rtype>::Op op;
  COLOSSAL_BLOCK_REDUCE_IMPL(T, pval, op, lanes, kDefaultValue, rtype);
}

#undef COLOSSAL_SHFL_FUNCTION
#undef COLOSSAL_WARP_REDUCE_IMPL
#undef COLOSSAL_BLOCK_REDUCE_IMPL

}  // namespace funcs
}  // namespace colossalAI

#endif /* defined(COLOSSAL_WITH_CUDA) */
|
221
extensions/csrc/funcs/ternary_functor.h
Normal file
221
extensions/csrc/funcs/ternary_functor.h
Normal file
@@ -0,0 +1,221 @@
|
||||
#pragma once

#if defined(COLOSSAL_WITH_CUDA)
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif

// Fix(review): use the C++ header <cfloat> instead of C's <float.h>, and
// include common/data_type.h directly — the dtype::half4/bfloat164/... types
// below were only reaching this header transitively via cast_functor.h.
#include <cfloat>
#include <functional>

#include "cast_functor.h"
#include "common/data_type.h"
#include "common/micros.h"

namespace colossalAI {
namespace funcs {

enum class TernaryOpType { kFma = 0 };

// Fused multiply-add style ternary ops: operator()(a, b, c) computes a*b+c,
// widening half/bf16 inputs to float results.
template <typename LT, typename RT, typename RET, TernaryOpType op_type>
struct TernaryOpFunctor;

#define STMTS_WRAPPER(...) __VA_ARGS__

#define COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(                     \
    LT, RT, RET, TERNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
  template <ARGS>                                                    \
  struct TernaryOpFunctor<LT, RT, RET, TERNARY_OP_TYPE> {            \
    FUNCTION_MODIFIER RET operator()(LT a, RT b, RET c) STMTS        \
  };

#if defined(COLOSSAL_WITH_CUDA)
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float, float,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float d;
                                          d = fma(a, b, c);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float2, float2, float2,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float2 d;
                                          d.x = fma(a.x, b.x, c.x);
                                          d.y = fma(a.y, b.y, c.y);
                                          return d;
                                        }))
// Scalar `a` broadcast across the vector lanes of b/c.
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float2, float2,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float2 d;
                                          d.x = fma(a, b.x, c.x);
                                          d.y = fma(a, b.y, c.y);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float4, float4, float4,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float4 d;
                                          d.x = fma(a.x, b.x, c.x);
                                          d.y = fma(a.y, b.y, c.y);
                                          d.z = fma(a.z, b.z, c.z);
                                          d.w = fma(a.w, b.w, c.w);
                                          return d;
                                        }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(float, float4, float4,
                                        TernaryOpType::kFma, DEVICE,
                                        STMTS_WRAPPER({
                                          float4 d;
                                          d.x = fma(a, b.x, c.x);
                                          d.y = fma(a, b.y, c.y);
                                          d.z = fma(a, b.z, c.z);
                                          d.w = fma(a, b.w, c.w);
                                          return d;
                                        }))

COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, half, float, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({ return __half2float(a) * __half2float(b) + c; }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half2, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      CastFunctor<half2, float2> cast;
      TernaryOpFunctor<float2, float2, float2, TernaryOpType::kFma> fma;
      float2 fa = cast(a);
      float2 fb = cast(b);
      return fma(fa, fb, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, half2, float2, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
      CastFunctor<half, half2> cast;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      return fma(cast(a), b, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      dtype::float4_ fd;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      dtype::float4_ fd;
      CastFunctor<half, half2> cast;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      half2 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      dtype::float8_ fd;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      fd.z = fma(a.z, b.z, c.z);
      fd.w = fma(a.w, b.w, c.w);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      dtype::float8_ fd;
      CastFunctor<half, half2> cast;
      TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
      half2 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      fd.z = fma(s, b.z, c.z);
      fd.w = fma(s, b.w, c.w);
      return fd;
    }))

COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat16, float, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({ return __bfloat162float(a) * __bfloat162float(b) + c; }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      CastFunctor<__nv_bfloat162, float2> cast;
      TernaryOpFunctor<float2, float2, float2, TernaryOpType::kFma> fma;
      float2 fa = cast(a);
      float2 fb = cast(b);
      return fma(fa, fb, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, __nv_bfloat162, float2, TernaryOpType::kFma, DEVICE,
    STMTS_WRAPPER({
      CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      return fma(cast(a), b, c);
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
    DEVICE, STMTS_WRAPPER({
      dtype::float4_ fd;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
    DEVICE, STMTS_WRAPPER({
      dtype::float4_ fd;
      CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      __nv_bfloat162 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
    DEVICE, STMTS_WRAPPER({
      dtype::float8_ fd;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      fd.x = fma(a.x, b.x, c.x);
      fd.y = fma(a.y, b.y, c.y);
      fd.z = fma(a.z, b.z, c.z);
      fd.w = fma(a.w, b.w, c.w);
      return fd;
    }))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
    __nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
    DEVICE, STMTS_WRAPPER({
      dtype::float8_ fd;
      CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
      TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
                       TernaryOpType::kFma>
          fma;
      __nv_bfloat162 s = cast(a);
      fd.x = fma(s, b.x, c.x);
      fd.y = fma(s, b.y, c.y);
      fd.z = fma(s, b.z, c.z);
      fd.w = fma(s, b.w, c.w);
      return fd;
    }))

#endif /* defined(COLOSSAL_WITH_CUDA) */

#undef COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION
#undef STMTS_WRAPPER

}  // namespace funcs
}  // namespace colossalAI
|
88
extensions/csrc/funcs/unary_functor.h
Normal file
88
extensions/csrc/funcs/unary_functor.h
Normal file
@@ -0,0 +1,88 @@
|
||||
#pragma once

#if defined(COLOSSAL_WITH_CUDA)
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#endif

#include <functional>

#include "common/data_type.h"
#include "common/micros.h"

namespace colossalAI {
namespace funcs {

// Fix(review): zero() is a device-only helper (__device__, #pragma unroll)
// but was declared outside any CUDA guard, breaking non-CUDA builds; it is
// now guarded like the other device code in this header.
#if defined(COLOSSAL_WITH_CUDA)
// Zero-fills dst by writing 32-bit words through a union.
// NOTE(review): assumes sizeof(T) is a multiple of 4 — confirm at callers.
template <typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ii++) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}
#endif /* defined(COLOSSAL_WITH_CUDA) */

// Note(LiuYang): As a retrieved table to check which operation is supported
// already
enum class UnaryOpType { kLog2Ceil = 0, kAbs, kSum };

// Note(LiuYang): Implementation of common and simple unary operators should be
// placed here, otherwise, they should be placed in a new file under functors
// dir.
template <typename From, typename To, UnaryOpType op_type>
struct UnaryOpFunctor;

// Fix(review): the specializations previously inherited from
// std::unary_function, which was deprecated in C++11 and removed in C++17;
// the typedefs it provided were unused, so the base class is dropped.
#define COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(             \
    FROM, TO, UNARY_OP_TYPE, FUNCTION_MODIFIER, STMTS, ARGS...) \
  template <ARGS>                                               \
  struct UnaryOpFunctor<FROM, TO, UNARY_OP_TYPE> {              \
    FUNCTION_MODIFIER TO operator()(FROM val) STMTS             \
  };

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(
    T, T, UnaryOpType::kAbs, HOSTDEVICE, { return std::abs(val); }, typename T)

// Ceil of log2: smallest e such that (1 << e) >= val.
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(int, int, UnaryOpType::kLog2Ceil,
                                      HOSTDEVICE, {
                                        int log2_value = 0;
                                        while ((1 << log2_value) < val)
                                          ++log2_value;
                                        return log2_value;
                                      })

#if defined(COLOSSAL_WITH_CUDA)

// Horizontal sums over CUDA vector types.
COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE,
                                      { return val.x + val.y; })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE,
                                      { return val.x + val.y + val.z + val.w; })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float4_, float, UnaryOpType::kSum,
                                      DEVICE, {
                                        return val.x.x + val.x.y + val.y.x +
                                               val.y.y;
                                      })

COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8_, float, UnaryOpType::kSum,
                                      DEVICE, {
                                        return val.x.x + val.x.y + val.y.x +
                                               val.y.y + val.z.x + val.z.y +
                                               val.w.x + val.w.y;
                                      })

#endif /* defined(COLOSSAL_WITH_CUDA) */

// Fix(review): the #undef previously misspelled the macro name
// (COLOSSAL_UARY_...), so the helper macro leaked to every includer.
#undef COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION

}  // namespace funcs
}  // namespace colossalAI
|
Reference in New Issue
Block a user