mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[Inference] Remove unnecessary float4_ and rename float8_ to float8 (#5679)
This commit is contained in:
@@ -94,29 +94,27 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
return fma(cast(a), b, c);
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
|
||||
dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
fd.x = fma(a.x, b.x, c.x);
|
||||
fd.y = fma(a.y, b.y, c.y);
|
||||
float4 fd;
|
||||
CastFunctor<dtype::half4, float4> cast;
|
||||
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast(a), cast(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
CastFunctor<half, half2> cast;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
half2 s = cast(a);
|
||||
fd.x = fma(s, b.x, c.x);
|
||||
fd.y = fma(s, b.y, c.y);
|
||||
half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({
|
||||
float4 fd;
|
||||
CastFunctor<half, float> cast0;
|
||||
CastFunctor<dtype::half4, float4> cast1;
|
||||
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast0(a), cast1(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
|
||||
dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
dtype::float8 fd;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
fd.x = fma(a.x, b.x, c.x);
|
||||
fd.y = fma(a.y, b.y, c.y);
|
||||
@@ -125,9 +123,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE,
|
||||
half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
dtype::float8 fd;
|
||||
CastFunctor<half, half2> cast;
|
||||
TernaryOpFunctor<half2, half2, float2, TernaryOpType::kFma> fma;
|
||||
half2 s = cast(a);
|
||||
@@ -160,33 +158,28 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
return fma(cast(a), b, c);
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
fma;
|
||||
fd.x = fma(a.x, b.x, c.x);
|
||||
fd.y = fma(a.y, b.y, c.y);
|
||||
dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fd;
|
||||
CastFunctor<dtype::bfloat164, float4> cast;
|
||||
TernaryOpFunctor<float4, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast(a), cast(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ fd;
|
||||
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
fma;
|
||||
__nv_bfloat162 s = cast(a);
|
||||
fd.x = fma(s, b.x, c.x);
|
||||
fd.y = fma(s, b.y, c.y);
|
||||
__nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
float4 fd;
|
||||
CastFunctor<__nv_bfloat16, float> cast0;
|
||||
CastFunctor<dtype::bfloat164, float4> cast1;
|
||||
TernaryOpFunctor<float, float4, float4, TernaryOpType::kFma> fma;
|
||||
fd = fma(cast0(a), cast1(b), c);
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
|
||||
dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
dtype::float8 fd;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
fma;
|
||||
@@ -197,9 +190,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
return fd;
|
||||
}))
|
||||
COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION(
|
||||
__nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma,
|
||||
DEVICE, STMTS_WRAPPER({
|
||||
dtype::float8_ fd;
|
||||
__nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
dtype::float8 fd;
|
||||
CastFunctor<__nv_bfloat16, __nv_bfloat162> cast;
|
||||
TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2,
|
||||
TernaryOpType::kFma>
|
||||
|
Reference in New Issue
Block a user