mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[Inference/Feat] Feat quant kvcache step2 (#5674)
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#endif
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
@@ -175,6 +176,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
|
||||
return res.x;
|
||||
}))
|
||||
|
||||
// half raw -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__half_raw tmp;
|
||||
tmp.x = val;
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(
|
||||
tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// fp8x2 -> half2 raw
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
union {
|
||||
@@ -222,6 +233,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
|
||||
return half(res);
|
||||
}))
|
||||
|
||||
// half -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__half_raw tmp(val);
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(
|
||||
tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// fp8x2 -> half2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
|
||||
__half2_raw res =
|
||||
@@ -230,6 +250,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
|
||||
return half2(res);
|
||||
}))
|
||||
|
||||
// half2 -> fp8x2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, uint16_t, DEVICE, STMTS_WRAPPER({
|
||||
__half2_raw tmp(val);
|
||||
__nv_fp8x2_storage_t res =
|
||||
__nv_cvt_halfraw2_to_fp8x2(
|
||||
tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint16_t>(res);
|
||||
}))
|
||||
|
||||
// fp8x4 -> half4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
|
||||
@@ -242,6 +271,20 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
return res;
|
||||
}))
|
||||
|
||||
// half4 -> fp8x4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::half4, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
half2 x, y;
|
||||
x = val.x;
|
||||
y = val.y;
|
||||
uint16_t lo, hi;
|
||||
lo = CastFunctor<half2, uint16_t>()(x);
|
||||
hi = CastFunctor<half2, uint16_t>()(y);
|
||||
uint32_t res;
|
||||
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(lo), "h"(hi));
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x8 -> half8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
|
||||
@@ -314,6 +357,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
return res;
|
||||
}))
|
||||
|
||||
// float -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(
|
||||
val, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// fp8x2 -> float2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint16_t, float2, DEVICE, STMTS_WRAPPER({
|
||||
@@ -328,6 +379,28 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
return make_float2(lof, hif);
|
||||
}))
|
||||
|
||||
// float2 -> fp8x2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
float2, uint16_t, DEVICE, STMTS_WRAPPER({
|
||||
uint16_t tmp1 =
|
||||
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
|
||||
uint16_t tmp2 =
|
||||
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
|
||||
uint16_t res = (tmp1 << 8U) | tmp2;
|
||||
return res;
|
||||
}))
|
||||
|
||||
// float4 -> fp8x4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
uint32_t a, b, c, d;
|
||||
a = CastFunctor<float, uint8_t>()(val.x);
|
||||
b = CastFunctor<float, uint8_t>()(val.y);
|
||||
c = CastFunctor<float, uint8_t>()(val.z);
|
||||
d = CastFunctor<float, uint8_t>()(val.w);
|
||||
return (a << 24U) | (b << 16U) |
|
||||
(c << 8U) | d;
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4_
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
|
||||
@@ -338,6 +411,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, float4, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}))
|
||||
|
||||
// fp8x8 -> float8_
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
|
||||
@@ -352,16 +433,6 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
return res;
|
||||
}))
|
||||
|
||||
// half -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__half_raw tmp;
|
||||
tmp.x = val;
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_halfraw_to_fp8(
|
||||
tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// bf16 -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
|
||||
STMTS_WRAPPER({
|
||||
@@ -376,19 +447,24 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
|
||||
#endif
|
||||
}))
|
||||
|
||||
// float -> fp8
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
|
||||
__nv_fp8_storage_t res =
|
||||
__nv_cvt_float_to_fp8(
|
||||
val, __NV_SATFINITE, __NV_E5M2);
|
||||
return static_cast<uint8_t>(res);
|
||||
}))
|
||||
|
||||
// fp8x4 -> float4
|
||||
// bf162 -> fp8x2
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
uint32_t, float4, DEVICE, STMTS_WRAPPER({
|
||||
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
__nv_bfloat162, uint16_t, DEVICE, STMTS_WRAPPER({
|
||||
uint16_t a =
|
||||
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
|
||||
uint16_t b =
|
||||
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
|
||||
return (a << 8U) | b;
|
||||
}))
|
||||
|
||||
// bf164 -> fp8x4
|
||||
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
|
||||
dtype::bfloat164, uint32_t, DEVICE, STMTS_WRAPPER({
|
||||
uint32_t res;
|
||||
uint16_t a, b;
|
||||
a = CastFunctor<__nv_bfloat162, uint16_t>()(val.x);
|
||||
b = CastFunctor<__nv_bfloat162, uint16_t>()(val.y);
|
||||
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(a), "h"(b));
|
||||
return res;
|
||||
}))
|
||||
|
||||
|
Reference in New Issue
Block a user