[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)

This commit is contained in:
Steve Luo
2024-05-06 10:55:34 +08:00
committed by GitHub
parent 537a3cbc4d
commit 725fbd2ed0
7 changed files with 112 additions and 147 deletions

View File

@@ -94,29 +94,27 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fma(cast(a), b, c);
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float4_ fd;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
fd.x = fma(a.x, b.x, c.x);
fd.y = fma(a.y, b.y, c.y);
float4 fd;
CastFunctor<dtype::half4, float4> cast;
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast(a), cast(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float4_ fd;
CastFunctor<half, half2> cast;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
half2 s = cast(a);
fd.x = fma(s, b.x, c.x);
fd.y = fma(s, b.y, c.y);
half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
float4 fd;
CastFunctor<half, float> cast0;
CastFunctor<dtype::half4, float4> cast1;
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast0(a), cast1(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float8_ fd;
dtype::float8 fd;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
fd.x = fma(a.x, b.x, c.x);
fd.y = fma(a.y, b.y, c.y);
@@ -125,9 +123,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float8_ fd;
dtype::float8 fd;
CastFunctor<half, half2> cast;
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
half2 s = cast(a);
@@ -160,33 +158,28 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fma(cast(a), b, c);
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float4_ fd;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>
fma;
fd.x = fma(a.x, b.x, c.x);
fd.y = fma(a.y, b.y, c.y);
dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
float4 fd;
CastFunctor<dtype::bfloat164, float4> cast;
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast(a), cast(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float4_ fd;
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>
fma;
__nv_bfloat162 s = cast(a);
fd.x = fma(s, b.x, c.x);
fd.y = fma(s, b.y, c.y);
__nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
float4 fd;
CastFunctor<__nv_bfloat16, float> cast0;
CastFunctor<dtype::bfloat164, float4> cast1;
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
fd = fma(cast0(a), cast1(b), c);
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float8_ fd;
dtype::float8 fd;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>
fma;
@@ -197,9 +190,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
return fd;
}))
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
__nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
DEVICE, STMTS_WRAPPER({
dtype::float8_ fd;
__nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE,
STMTS_WRAPPER({
dtype::float8 fd;
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
TernaryOpType::kFma>