[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)

This commit is contained in:
Steve Luo
2024-05-06 10:55:34 +08:00
committed by GitHub
parent 537a3cbc4d
commit 725fbd2ed0
7 changed files with 112 additions and 147 deletions

View File

@@ -69,14 +69,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE,
dst.y = __floats2half2_rn(val.z, val.w);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE,
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::half4, float4, DEVICE,
STMTS_WRAPPER({
dtype::half4 dst;
dst.x = __float22half2_rn(val.x);
dst.y = __float22half2_rn(val.y);
float4 dst;
dst.x = __half2float(val.x.x);
dst.y = __half2float(val.x.y);
dst.z = __half2float(val.y.x);
dst.w = __half2float(val.y.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE,
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::half8, DEVICE,
STMTS_WRAPPER({
dtype::half8 dst;
dst.x = __float22half2_rn(val.x);
@@ -107,6 +109,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE,
__floats2bfloat162_rn(val.z, val.w);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::bfloat164, float4, DEVICE,
STMTS_WRAPPER({
float4 dst;
dst.x = __bfloat162float(val.x.x);
dst.y = __bfloat162float(val.x.y);
dst.z = __bfloat162float(val.y.x);
dst.w = __bfloat162float(val.y.y);
return dst;
}))
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
@@ -120,14 +131,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
STMTS_WRAPPER({
return __float22bfloat162_rn(val);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE,
STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x = __float22bfloat162_rn(val.x);
dst.y = __float22bfloat162_rn(val.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE,
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::bfloat168, DEVICE,
STMTS_WRAPPER({
dtype::bfloat168 dst;
dst.x = __float22bfloat162_rn(val.x);
@@ -155,14 +159,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE,
val.y);
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({
dtype::bfloat164 dst;
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
return dst;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({
dtype::bfloat168 dst;
dst.x = __floats2bfloat162_rn(val.x.x, val.x.y);
dst.y = __floats2bfloat162_rn(val.y.x, val.y.y);
@@ -405,35 +402,27 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
(b << 8U) | a;
}))
// fp8x4 -> float4_
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
dtype::float4_ res;
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val));
res.y =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val >> 16U));
return res;
}))
// fp8x4 -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, float4, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
float4 res;
res.x = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val));
res.y = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 8U));
res.z = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 16U));
res.w = CastFunctor<uint8_t, float>()(static_cast<uint8_t>(val >> 24U));
return res;
}))
// fp8x8 -> float8_
// fp8x8 -> float8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp1, tmp2;
tmp1 = CastFunctor<uint32_t, dtype::float4_>()(val.x);
tmp2 = CastFunctor<uint32_t, dtype::float4_>()(val.y);
dtype::float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
uint2, dtype::float8, DEVICE, STMTS_WRAPPER({
dtype::float8 res;
res.x = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x));
res.y =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.x >> 16U));
res.z = CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y));
res.w =
CastFunctor<uint16_t, float2>()(static_cast<uint16_t>(val.y >> 16U));
return res;
}))
@@ -482,34 +471,22 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({
return uint32;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE,
STMTS_WRAPPER({
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint2, DEVICE, STMTS_WRAPPER({
uint2 b;
float2 c;
c.x = val.x.x;
c.y = val.x.y;
c.x = val.x;
c.y = val.y;
b.x = CastFunctor<float2, uint32_t>()(c);
c.x = val.y.x;
c.y = val.y.y;
c.x = val.z;
c.y = val.w;
b.y = CastFunctor<float2, uint32_t>()(c);
return b;
}))
// float4_ -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE,
STMTS_WRAPPER({
float4 b;
b.x = val.x.x;
b.y = val.x.y;
b.z = val.y.x;
b.w = val.y.y;
return b;
}))
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({
dtype::float8, uint4, DEVICE, STMTS_WRAPPER({
uint4 b;
b.x = CastFunctor<float2, uint32_t>()(val.x);
b.y = CastFunctor<float2, uint32_t>()(val.y);