mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)
This commit is contained in:
@@ -164,22 +164,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
return mul(fa, fb);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ fc;
|
||||
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
BinaryOpType::kMul>
|
||||
mul;
|
||||
fc.x = mul(lhs.x, rhs.x);
|
||||
fc.y = mul(lhs.y, rhs.y);
|
||||
return fc;
|
||||
}))
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::bfloat164, dtype::bfloat164,
|
||||
float4, BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fc;
|
||||
CastFunctor<__nv_bfloat16, float> cast;
|
||||
fc.x = cast(lhs.x.x) * cast(rhs.x.x);
|
||||
fc.y = cast(lhs.x.y) * cast(rhs.x.y);
|
||||
fc.z = cast(lhs.y.x) * cast(rhs.y.x);
|
||||
fc.w = cast(lhs.y.y) * cast(rhs.y.y);
|
||||
return fc;
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul,
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8, BinaryOpType::kMul,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8_ fc;
|
||||
dtype::float8 fc;
|
||||
BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
BinaryOpType::kMul>
|
||||
mul;
|
||||
@@ -199,20 +199,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
return mul(fa, fb);
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float4_ fc;
|
||||
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
||||
fc.x = mul(lhs.x, rhs.x);
|
||||
fc.y = mul(lhs.y, rhs.y);
|
||||
return fc;
|
||||
}))
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::half4, dtype::half4, float4,
|
||||
BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fc;
|
||||
CastFunctor<half, float> cast;
|
||||
fc.x = cast(lhs.x.x) * cast(rhs.x.x);
|
||||
fc.y = cast(lhs.x.y) * cast(rhs.x.y);
|
||||
fc.z = cast(lhs.y.x) * cast(rhs.y.x);
|
||||
fc.w = cast(lhs.y.y) * cast(rhs.y.y);
|
||||
return fc;
|
||||
}))
|
||||
|
||||
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE,
|
||||
dtype::half8, dtype::half8, dtype::float8, BinaryOpType::kMul, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8_ fc;
|
||||
dtype::float8 fc;
|
||||
BinaryOpFunctor<half2, half2, float2, BinaryOpType::kMul> mul;
|
||||
fc.x = mul(lhs.x, rhs.x);
|
||||
fc.y = mul(lhs.y, rhs.y);
|
||||
|
Reference in New Issue
Block a user