From 9df016fc4520a5a5c95a11ed04a8ac62bde039c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Tue, 30 Apr 2024 19:38:00 +0800 Subject: [PATCH] [Inference] Fix quant bits order (#5681) --- extensions/csrc/funcs/cast_functor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/extensions/csrc/funcs/cast_functor.h b/extensions/csrc/funcs/cast_functor.h index 6382d5271..170abd596 100644 --- a/extensions/csrc/funcs/cast_functor.h +++ b/extensions/csrc/funcs/cast_functor.h @@ -390,7 +390,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( static_cast(CastFunctor()(val.x)); uint16_t tmp2 = static_cast(CastFunctor()(val.y)); - uint16_t res = (tmp1 << 8U) | tmp2; + uint16_t res = (tmp2 << 8U) | tmp1; return res; })) @@ -401,8 +401,8 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({ b = CastFunctor()(val.y); c = CastFunctor()(val.z); d = CastFunctor()(val.w); - return (a << 24U) | (b << 16U) | - (c << 8U) | d; + return (d << 24U) | (c << 16U) | + (b << 8U) | a; })) // fp8x4 -> float4_ @@ -458,7 +458,7 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION( static_cast(CastFunctor<__nv_bfloat16, uint8_t>()(val.x)); uint16_t b = static_cast(CastFunctor<__nv_bfloat16, uint8_t>()(val.y)); - return (a << 8U) | b; + return (b << 8U) | a; })) // bf164 -> fp8x4