diff --git a/extensions/csrc/common/data_type.h b/extensions/csrc/common/data_type.h index 1327c51d3..7cc7cfabb 100644 --- a/extensions/csrc/common/data_type.h +++ b/extensions/csrc/common/data_type.h @@ -40,14 +40,7 @@ struct half8 { #endif }; -struct float4_ { -#ifdef COLOSSAL_WITH_CUDA - float2 x; - float2 y; -#endif -}; - -struct float8_ { +struct float8 { #ifdef COLOSSAL_WITH_CUDA float2 x; float2 y; diff --git a/extensions/csrc/common/vec_type_traits.h b/extensions/csrc/common/vec_type_traits.h index f7e70e22c..9e12ab71b 100644 --- a/extensions/csrc/common/vec_type_traits.h +++ b/extensions/csrc/common/vec_type_traits.h @@ -49,7 +49,7 @@ VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4); VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8); VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2) VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4) -VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_) +VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8) #endif /* defined(COLOSSAL_WITH_CUDA) */ #undef VEC_TYPE_TRAITS_SPECIALIZATION @@ -64,11 +64,11 @@ VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_) FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float2, float2) FLOATVEC_TYPE_TRAITS_SPECIALIZATION(float4, float4) FLOATVEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat162, float2); -FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, dtype::float4_); -FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8_); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat164, float4); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::bfloat168, dtype::float8); FLOATVEC_TYPE_TRAITS_SPECIALIZATION(half2, float2); -FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, dtype::float4_); -FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8_); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half4, float4); +FLOATVEC_TYPE_TRAITS_SPECIALIZATION(dtype::half8, dtype::float8); #endif /* COLOSSAL_WITH_CUDA */ #undef FLOATVEC_TYPE_TRAITS_SPECIALIZATION diff --git a/extensions/csrc/funcs/binary_functor.h b/extensions/csrc/funcs/binary_functor.h index 822f131c2..90726a02f 100644 --- a/extensions/csrc/funcs/binary_functor.h +++ b/extensions/csrc/funcs/binary_functor.h @@ -164,22 +164,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( return mul(fa, fb); })) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( - dtype::bfloat164, dtype::bfloat164, dtype::float4_, BinaryOpType::kMul, - DEVICE, STMTS_WRAPPER({ - dtype::float4_ fc; - BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, - BinaryOpType::kMul> - mul; - fc.x = mul(lhs.x, rhs.x); - fc.y = mul(lhs.y, rhs.y); - return fc; - })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::bfloat164, dtype::bfloat164, + float4, BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + float4 fc; + CastFunctor<__nv_bfloat16, float> cast; + fc.x = cast(lhs.x.x) * cast(rhs.x.x); + fc.y = cast(lhs.x.y) * cast(rhs.x.y); + fc.z = cast(lhs.y.x) * cast(rhs.y.x); + fc.w = cast(lhs.y.y) * cast(rhs.y.y); + return fc; + })) COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( - dtype::bfloat168, dtype::bfloat168, dtype::float8_, BinaryOpType::kMul, + dtype::bfloat168, dtype::bfloat168, dtype::float8, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ - dtype::float8_ fc; + dtype::float8 fc; BinaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, BinaryOpType::kMul> mul; @@ -199,20 +199,22 @@ COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( return mul(fa, fb); })) -COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( - dtype::half4, dtype::half4, dtype::float4_, BinaryOpType::kMul, DEVICE, - STMTS_WRAPPER({ - dtype::float4_ fc; - BinaryOpFunctor mul; - fc.x = mul(lhs.x, rhs.x); - fc.y = mul(lhs.y, rhs.y); - return fc; - })) +COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(dtype::half4, dtype::half4, float4, + BinaryOpType::kMul, DEVICE, + STMTS_WRAPPER({ + float4 fc; + CastFunctor cast; + fc.x = cast(lhs.x.x) * cast(rhs.x.x); + fc.y = cast(lhs.x.y) * cast(rhs.x.y); + fc.z = cast(lhs.y.x) * cast(rhs.y.x); + fc.w = cast(lhs.y.y) * cast(rhs.y.y); + return fc; + })) COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION( - dtype::half8, dtype::half8, dtype::float8_, BinaryOpType::kMul, DEVICE, + dtype::half8, dtype::half8, dtype::float8, BinaryOpType::kMul, DEVICE, STMTS_WRAPPER({ - dtype::float8_ fc; + dtype::float8 fc; BinaryOpFunctor mul; fc.x = mul(lhs.x, rhs.x); fc.y = mul(lhs.y, rhs.y); diff --git a/extensions/csrc/funcs/cast_functor.h b/extensions/csrc/funcs/cast_functor.h index 170abd596..588357d6b 100644 --- a/extensions/csrc/funcs/cast_functor.h +++ b/extensions/csrc/funcs/cast_functor.h @@ -69,14 +69,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::half4, DEVICE, dst.y = __floats2half2_rn(val.z, val.w); return dst; })) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::half4, DEVICE, +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::half4, float4, DEVICE, STMTS_WRAPPER({ - dtype::half4 dst; - dst.x = __float22half2_rn(val.x); - dst.y = __float22half2_rn(val.y); + float4 dst; + dst.x = __half2float(val.x.x); + dst.y = __half2float(val.x.y); + dst.z = __half2float(val.y.x); + dst.w = __half2float(val.y.y); return dst; })) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::half8, DEVICE, +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::half8, DEVICE, STMTS_WRAPPER({ dtype::half8 dst; dst.x = __float22half2_rn(val.x); @@ -107,6 +109,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, dtype::bfloat164, DEVICE, __floats2bfloat162_rn(val.z, val.w); return dst; })) +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::bfloat164, float4, DEVICE, + STMTS_WRAPPER({ + float4 dst; + dst.x = __bfloat162float(val.x.x); + dst.y = __bfloat162float(val.x.y); + dst.z = __bfloat162float(val.y.x); + dst.w = __bfloat162float(val.y.y); + return dst; + })) #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, __nv_bfloat162, DEVICE, STMTS_WRAPPER({ @@ -120,14 +131,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE, STMTS_WRAPPER({ return __float22bfloat162_rn(val); })) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, dtype::bfloat164, DEVICE, - STMTS_WRAPPER({ - dtype::bfloat164 dst; - dst.x = __float22bfloat162_rn(val.x); - dst.y = __float22bfloat162_rn(val.y); - return dst; - })) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8_, dtype::bfloat168, DEVICE, +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({ dtype::bfloat168 dst; dst.x = __float22bfloat162_rn(val.x); @@ -155,14 +159,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, __nv_bfloat162, DEVICE, val.y); })) COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( - dtype::float4_, dtype::bfloat164, DEVICE, STMTS_WRAPPER({ - dtype::bfloat164 dst; - dst.x = __floats2bfloat162_rn(val.x.x, val.x.y); - dst.y = __floats2bfloat162_rn(val.y.x, val.y.y); - return dst; - })) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( - dtype::float8_, dtype::bfloat168, DEVICE, STMTS_WRAPPER({ + dtype::float8, dtype::bfloat168, DEVICE, STMTS_WRAPPER({ dtype::bfloat168 dst; dst.x = __floats2bfloat162_rn(val.x.x, val.x.y); dst.y = __floats2bfloat162_rn(val.y.x, val.y.y); @@ -405,35 +402,27 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({ (b << 8U) | a; })) -// fp8x4 -> float4_ -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( - uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({ - dtype::float4_ res; - res.x = CastFunctor()(static_cast(val)); - res.y = - CastFunctor()(static_cast(val >> 16U)); - return res; - })) - // fp8x4 -> float4 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( uint32_t, float4, DEVICE, STMTS_WRAPPER({ - dtype::float4_ tmp = CastFunctor()(val); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + float4 res; + res.x = CastFunctor()(static_cast(val)); + res.y = CastFunctor()(static_cast(val >> 8U)); + res.z = CastFunctor()(static_cast(val >> 16U)); + res.w = CastFunctor()(static_cast(val >> 24U)); return res; })) -// fp8x8 -> float8_ +// fp8x8 -> float8 COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( - uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({ - dtype::float4_ tmp1, tmp2; - tmp1 = CastFunctor()(val.x); - tmp2 = CastFunctor()(val.y); - dtype::float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; + uint2, dtype::float8, DEVICE, STMTS_WRAPPER({ + dtype::float8 res; + res.x = CastFunctor()(static_cast(val.x)); + res.y = + CastFunctor()(static_cast(val.x >> 16U)); + res.z = CastFunctor()(static_cast(val.y)); + res.w = + CastFunctor()(static_cast(val.y >> 16U)); return res; })) @@ -482,34 +471,22 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, uint32_t, DEVICE, STMTS_WRAPPER({ return uint32; })) -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, uint2, DEVICE, - STMTS_WRAPPER({ +COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint2, DEVICE, STMTS_WRAPPER({ uint2 b; float2 c; - c.x = val.x.x; - c.y = val.x.y; + c.x = val.x; + c.y = val.y; b.x = CastFunctor()(c); - c.x = val.y.x; - c.y = val.y.y; + c.x = val.z; + c.y = val.w; b.y = CastFunctor()(c); return b; })) -// float4_ -> float4 -COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(dtype::float4_, float4, DEVICE, - STMTS_WRAPPER({ - float4 b; - b.x = val.x.x; - b.y = val.x.y; - b.z = val.y.x; - b.w = val.y.y; - return b; - })) - COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( - dtype::float8_, uint4, DEVICE, STMTS_WRAPPER({ + dtype::float8, uint4, DEVICE, STMTS_WRAPPER({ uint4 b; b.x = CastFunctor()(val.x); b.y = CastFunctor()(val.y); diff --git a/extensions/csrc/funcs/ternary_functor.h b/extensions/csrc/funcs/ternary_functor.h index c7d8039de..8d0c95f10 100644 --- a/extensions/csrc/funcs/ternary_functor.h +++ b/extensions/csrc/funcs/ternary_functor.h @@ -94,29 +94,27 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( return fma(cast(a), b, c); })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - dtype::half4, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE, + dtype::half4, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ - dtype::float4_ fd; - TernaryOpFunctor fma; - fd.x = fma(a.x, b.x, c.x); - fd.y = fma(a.y, b.y, c.y); + float4 fd; + CastFunctor cast; + TernaryOpFunctor fma; + fd = fma(cast(a), cast(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - half, dtype::half4, dtype::float4_, TernaryOpType::kFma, DEVICE, - STMTS_WRAPPER({ - dtype::float4_ fd; - CastFunctor cast; - TernaryOpFunctor fma; - half2 s = cast(a); - fd.x = fma(s, b.x, c.x); - fd.y = fma(s, b.y, c.y); + half, dtype::half4, float4, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ + float4 fd; + CastFunctor cast0; + CastFunctor cast1; + TernaryOpFunctor fma; + fd = fma(cast0(a), cast1(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - dtype::half8, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE, + dtype::half8, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ - dtype::float8_ fd; + dtype::float8 fd; TernaryOpFunctor fma; fd.x = fma(a.x, b.x, c.x); fd.y = fma(a.y, b.y, c.y); @@ -125,9 +123,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - half, dtype::half8, dtype::float8_, TernaryOpType::kFma, DEVICE, + half, dtype::half8, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ - dtype::float8_ fd; + dtype::float8 fd; CastFunctor cast; TernaryOpFunctor fma; half2 s = cast(a); @@ -160,33 +158,28 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( return fma(cast(a), b, c); })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - dtype::bfloat164, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma, - DEVICE, STMTS_WRAPPER({ - dtype::float4_ fd; - TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, - TernaryOpType::kFma> - fma; - fd.x = fma(a.x, b.x, c.x); - fd.y = fma(a.y, b.y, c.y); + dtype::bfloat164, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 fd; + CastFunctor cast; + TernaryOpFunctor fma; + fd = fma(cast(a), cast(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - __nv_bfloat16, dtype::bfloat164, dtype::float4_, TernaryOpType::kFma, - DEVICE, STMTS_WRAPPER({ - dtype::float4_ fd; - CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; - TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, - TernaryOpType::kFma> - fma; - __nv_bfloat162 s = cast(a); - fd.x = fma(s, b.x, c.x); - fd.y = fma(s, b.y, c.y); + __nv_bfloat16, dtype::bfloat164, float4, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + float4 fd; + CastFunctor<__nv_bfloat16, float> cast0; + CastFunctor cast1; + TernaryOpFunctor fma; + fd = fma(cast0(a), cast1(b), c); return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - dtype::bfloat168, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma, + dtype::bfloat168, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE, STMTS_WRAPPER({ - dtype::float8_ fd; + dtype::float8 fd; TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma> fma; @@ -197,9 +190,9 @@ COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( return fd; })) COLOSSAL_TERNARY_FUNCTOR_SPECIALIZATION( - __nv_bfloat16, dtype::bfloat168, dtype::float8_, TernaryOpType::kFma, - DEVICE, STMTS_WRAPPER({ - dtype::float8_ fd; + __nv_bfloat16, dtype::bfloat168, dtype::float8, TernaryOpType::kFma, DEVICE, + STMTS_WRAPPER({ + dtype::float8 fd; CastFunctor<__nv_bfloat16, __nv_bfloat162> cast; TernaryOpFunctor<__nv_bfloat162, __nv_bfloat162, float2, TernaryOpType::kFma> diff --git a/extensions/csrc/funcs/unary_functor.h b/extensions/csrc/funcs/unary_functor.h index ea75018df..207a0ff97 100644 --- a/extensions/csrc/funcs/unary_functor.h +++ b/extensions/csrc/funcs/unary_functor.h @@ -52,13 +52,7 @@ COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float2, float, UnaryOpType::kSum, DEVICE, COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(float4, float, UnaryOpType::kSum, DEVICE, { return val.x + val.y + val.z + val.w; }) -COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float4_, float, UnaryOpType::kSum, - DEVICE, { - return val.x.x + val.x.y + val.y.x + - val.y.y; - }) - -COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8_, float, UnaryOpType::kSum, +COLOSSAL_UNARY_FUNCTOR_SPECIALIZATION(dtype::float8, float, UnaryOpType::kSum, DEVICE, { return val.x.x + val.x.y + val.y.x + val.y.y + val.z.x + val.z.y + diff --git a/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu b/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu index c9bd3d72d..ca359df8d 100644 --- a/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/kernel/cuda/rms_layernorm_kernel.cu @@ -283,11 +283,14 @@ void rms_layernorm( case 4: RMSNORM_LAUNCHER(4, block); break; + case 5: + RMSNORM_LAUNCHER(5, block); + break; case 8: RMSNORM_LAUNCHER(8, block); break; default: - AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8"); + AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8"); } } } @@ -330,11 +333,14 @@ void fused_add_rms_layernorm( case 4: FUSED_ADD_RMSNORM_LAUNCHER(4, block); break; + case 5: + FUSED_ADD_RMSNORM_LAUNCHER(5, block); + break; case 8: FUSED_ADD_RMSNORM_LAUNCHER(8, block); break; default: - AT_ERROR("unroll_factor must be 1, 2, 3, 4 or 8"); + AT_ERROR("unroll_factor must be 1, 2, 3, 4, 5 or 8"); } } }