[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)

This commit is contained in:
Steve Luo
2024-05-06 10:55:34 +08:00
committed by GitHub
parent 537a3cbc4d
commit 725fbd2ed0
7 changed files with 112 additions and 147 deletions

View File

@@ -164,22 +164,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
return mul(fa, fb);
}))
// kMul specialization: bfloat164 * bfloat164 -> dtype::float4_.
// The .x and .y members of both operands are multiplied pairwise by the
// (__nv_bfloat162, __nv_bfloat162) -> float2 multiply functor, and the two
// results are stored in the .x and .y members of the returned value.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul,
DEVICE, STMTS_WRAPPER({
using PairMul = BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
BinaryOpType::kMul>;
dtype::float4_ product;
PairMul pair_mul;
product.x = pair_mul(lhs.x, rhs.x);
product.y = pair_mul(lhs.y, rhs.y);
return product;
}))
// kMul specialization: bfloat164 * bfloat164 -> float4.
// Each of the four __nv_bfloat16 lanes (x.x, x.y, y.x, y.y) of both operands
// is converted to float with CastFunctor and the lane-wise products are
// written to the x, y, z, w members of the result, in that order.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::bfloat164, dtype::bfloat164,
float4, BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
CastFunctor<__nv_bfloat16, float> to_float;
const float lane0 = to_float(lhs.x.x) * to_float(rhs.x.x);
const float lane1 = to_float(lhs.x.y) * to_float(rhs.x.y);
const float lane2 = to_float(lhs.y.x) * to_float(rhs.y.x);
const float lane3 = to_float(lhs.y.y) * to_float(rhs.y.y);
float4 product;
product.x = lane0;
product.y = lane1;
product.z = lane2;
product.w = lane3;
return product;
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul,
dtype::bfloat168, dtype::bfloat168, dtype::float8, BinaryOpType::kMul,
DEVICE, STMTS_WRAPPER({
dtype::float8_ fc;
dtype::float8 fc;
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
BinaryOpType::kMul>
mul;
@@ -199,20 +199,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
return mul(fa, fb);
}))
// kMul specialization: half4 * half4 -> dtype::float4_.
// The .x and .y members of both operands are multiplied pairwise by the
// (half2, half2) -> float2 multiply functor, and the two results are stored
// in the .x and .y members of the returned value.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
using PairMul = BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul>;
dtype::float4_ product;
PairMul pair_mul;
product.x = pair_mul(lhs.x, rhs.x);
product.y = pair_mul(lhs.y, rhs.y);
return product;
}))
// kMul specialization: half4 * half4 -> float4.
// Each of the four half lanes (x.x, x.y, y.x, y.y) of both operands is
// converted to float with CastFunctor and the lane-wise products are written
// to the x, y, z, w members of the result, in that order.
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::half4, dtype::half4, float4,
BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
CastFunctor<half, float> to_float;
const float lane0 = to_float(lhs.x.x) * to_float(rhs.x.x);
const float lane1 = to_float(lhs.x.y) * to_float(rhs.x.y);
const float lane2 = to_float(lhs.y.x) * to_float(rhs.y.x);
const float lane3 = to_float(lhs.y.y) * to_float(rhs.y.y);
float4 product;
product.x = lane0;
product.y = lane1;
product.z = lane2;
product.w = lane3;
return product;
}))
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE,
dtype::half8, dtype::half8, dtype::float8, BinaryOpType::kMul, DEVICE,
STMTS_WRAPPER({
dtype::float8_ fc;
dtype::float8 fc;
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
fc.x = mul(lhs.x, rhs.x);
fc.y = mul(lhs.y, rhs.y);