mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)
This commit is contained in:
@@ -69,14 +69,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
|
||||
dst.y = __floats2half2_rn(val.z, val.w);
|
||||
return dst;
|
||||
}))
|
||||
// NOTE(review): two specialization headers appear back to back below
// (dtype::float4_ -> dtype::half4, then dtype::half4 -> float4) and `dst` is
// declared twice inside one STMTS_WRAPPER body. This looks like the old and
// new revisions of a change shown interleaved; confirm which lines belong to
// which revision before relying on this span.
// NOTE(review): presumably the macro defines the CastFunctor<From, To>
// specialization and supplies the input as `val`; the macro definition is not
// visible here, so verify.
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::half4, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
// float4_ -> half4 variant: each of val.x / val.y is narrowed as a pair with
// __float22half2_rn.
dtype::half4 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
dst.y = __float22half2_rn(val.y);
|
||||
// half4 -> float4 variant: each half element is widened individually, in the
// order val.x.x, val.x.y, val.y.x, val.y.y -> dst.x, dst.y, dst.z, dst.w.
float4 dst;
|
||||
dst.x = __half2float(val.x.x);
|
||||
dst.y = __half2float(val.x.y);
|
||||
dst.z = __half2float(val.y.x);
|
||||
dst.w = __half2float(val.y.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::half8, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::half8 dst;
|
||||
dst.x = __float22half2_rn(val.x);
|
||||
@@ -107,6 +109,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
|
||||
__floats2bfloat162_rn(val.z, val.w);
|
||||
return dst;
|
||||
}))
|
||||
// bfloat164 -> float4
// Widens the four bfloat16 elements of `val` to a float4. Each element is
// converted individually with __bfloat162float, in the order
// val.x.x, val.x.y, val.y.x, val.y.y -> dst.x, dst.y, dst.z, dst.w.
// NOTE(review): presumably the macro defines the
// CastFunctor<dtype::bfloat164, float4> specialization and supplies the input
// as `val`; the macro definition is not visible here, so verify.
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::bfloat164, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 dst;
|
||||
dst.x = __bfloat162float(val.x.x);
|
||||
dst.y = __bfloat162float(val.x.y);
|
||||
dst.z = __bfloat162float(val.y.x);
|
||||
dst.w = __bfloat162float(val.y.y);
|
||||
return dst;
|
||||
}))
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
@@ -120,14 +131,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
return __float22bfloat162_rn(val);
|
||||
}))
|
||||
// float4_ -> bfloat164
// Narrows the two halves of `val` (val.x and val.y) to bfloat16 pairs, one
// __float22bfloat162_rn call per half, and stores them in dst.x and dst.y.
// NOTE(review): presumably the macro defines the
// CastFunctor<dtype::float4_, dtype::bfloat164> specialization with `val` as
// the input; the macro definition is not visible here, so verify.
// NOTE(review): the commit title on this page says dtype::float4_ is being
// removed, so this specialization may belong to the pre-change revision only;
// confirm.
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
dst.y = __float22bfloat162_rn(val.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::bfloat168, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __float22bfloat162_rn(val.x);
|
||||
@@ -155,14 +159,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
|
||||
val.y);
|
||||
}))
|
||||
// float4_ -> bfloat164
// Same conversion shape as a pairwise float2 narrowing, but written with the
// two-scalar intrinsic: each output pair is built from two individual floats
// via __floats2bfloat162_rn (val.x.x, val.x.y -> dst.x; val.y.x, val.y.y ->
// dst.y).
// NOTE(review): presumably the macro defines the
// CastFunctor<dtype::float4_, dtype::bfloat164> specialization with `val` as
// the input; the macro definition is not visible here, so verify.
// NOTE(review): the commit title on this page says dtype::float4_ is being
// removed, so this specialization may belong to the pre-change revision only;
// confirm.
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat164 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||
return dst;
|
||||
}))
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
|
||||
dtype::bfloat168 dst;
|
||||
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
|
||||
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
|
||||
@@ -405,35 +402,27 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
(b << 8U) | a;
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4_
// Splits the 32-bit input into its low 16 bits and its high 16 bits and
// converts each 16-bit piece with CastFunctor<uint16_t, float2>; the low piece
// fills res.x and the high piece fills res.y.
// NOTE(review): presumably each 16-bit piece carries two packed fp8 values
// (per the "fp8x4" label); the fp8 encoding is decided by the uint16_t ->
// float2 functor, which is not visible here.
// NOTE(review): the commit title on this page says dtype::float4_ is being
// removed, so this specialization may belong to the pre-change revision only;
// confirm.
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ res;
|
||||
// static_cast<uint16_t> keeps only the low 16 bits of val.
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
|
||||
res.y =
|
||||
// Shift first so the cast keeps the upper 16 bits of val.
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4
// NOTE(review): `res` is declared twice in the body below, once built from a
// dtype::float4_ temporary and once filled byte by byte. This looks like the
// old and new revisions of the same specialization shown interleaved; confirm
// which lines belong to which revision.
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, float4, DEVICE, STMTS_WRAPPER({
|
||||
// Variant A: convert through the uint32_t -> dtype::float4_ functor, then
// flatten the two float2 halves into a float4.
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
// Variant B: convert each byte of val separately with
// CastFunctor<uint8_t, float>; byte 0 (lowest) -> res.x ... byte 3 -> res.w.
float4 res;
|
||||
res.x = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val));
|
||||
res.y = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 8U));
|
||||
res.z = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 16U));
|
||||
res.w = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 24U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// NOTE(review): two label comments and two parameter lines (dtype::float8_
// and dtype::float8) appear in this span, and `res` is declared twice. This
// looks like the old and new revisions of the same specialization shown
// interleaved; the commit title on this page says float8_ is renamed to
// float8. Confirm which lines belong to which revision.
// fp8x8 -> float8_
|
||||
// fp8x8 -> float8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
|
||||
// float8_ variant: convert each 32-bit word of val through the
// uint32_t -> dtype::float4_ functor, then copy the four resulting halves
// into res.x .. res.w.
dtype::float4_ tmp1, tmp2;
|
||||
tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
|
||||
tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
|
||||
dtype::float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
uint2, dtype::float8, DEVICE, STMTS_WRAPPER({
|
||||
// float8 variant: skip the float4_ temporaries and convert the low and high
// 16 bits of val.x and val.y directly with CastFunctor<uint16_t, float2>.
dtype::float8 res;
|
||||
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x));
|
||||
res.y =
|
||||
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x >> 16U));
|
||||
res.z = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y));
|
||||
res.w =
|
||||
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y >> 16U));
|
||||
return res;
|
||||
}))
|
||||
|
||||
@@ -482,34 +471,22 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
return uint32;
|
||||
}))
|
||||
|
||||
// NOTE(review): two specialization headers (dtype::float4_ -> uint2 and
// float4 -> uint2) and two sets of assignments to `c` appear in this span.
// This looks like the old and new revisions of the same specialization shown
// interleaved; confirm which lines belong to which revision.
// In both variants the four input floats are packed two at a time: a float2
// scratch value `c` is filled and converted with CastFunctor<float2, uint32_t>,
// first into b.x and then into b.y.
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint2, DEVICE, STMTS_WRAPPER({
|
||||
uint2 b;
|
||||
float2 c;
|
||||
// First pair, float4_ form (nested members).
c.x = val.x.x;
|
||||
c.y = val.x.y;
|
||||
// First pair, float4 form (flat members).
c.x = val.x;
|
||||
c.y = val.y;
|
||||
b.x = CastFunctor<float2, uint32_t>()(c);
|
||||
|
||||
// Second pair, float4_ form (nested members).
c.x = val.y.x;
|
||||
c.y = val.y.y;
|
||||
// Second pair, float4 form (flat members).
c.x = val.z;
|
||||
c.y = val.w;
|
||||
b.y = CastFunctor<float2, uint32_t>()(c);
|
||||
|
||||
return b;
|
||||
}))
|
||||
|
||||
// float4_ -> float4
// Flattens the nested layout of `val` into a float4 by plain member copies,
// in the order val.x.x, val.x.y, val.y.x, val.y.y -> b.x, b.y, b.z, b.w.
// No numeric conversion is involved.
// NOTE(review): the commit title on this page says dtype::float4_ is being
// removed, so this specialization may belong to the pre-change revision only;
// confirm.
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 b;
|
||||
b.x = val.x.x;
|
||||
b.y = val.x.y;
|
||||
b.z = val.y.x;
|
||||
b.w = val.y.y;
|
||||
return b;
|
||||
}))
|
||||
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8, uint4, DEVICE, STMTS_WRAPPER({
|
||||
uint4 b;
|
||||
b.x = CastFunctor<float2, uint32_t>()(val.x);
|
||||
b.y = CastFunctor<float2, uint32_t>()(val.y);
|
||||
|
Reference in New Issue
Block a user