[Inference/Feat] Feat quant kvcache step2 (#5674)

This commit is contained in:
傅剑寒
2024-04-30 11:26:36 +08:00
committed by GitHub
parent 8ccb6714e7
commit 808ee6e4ad
4 changed files with 208 additions and 71 deletions

View File

@@ -9,6 +9,7 @@
#endif
#include <assert.h>
#include <stdint.h>
#include <functional>
@@ -175,6 +176,16 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, uint16_t, DEVICE, STMTS_WRAPPER({
return res.x;
}))
// half raw -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
__half_raw tmp;
tmp.x = val;
__nv_fp8_storage_t res =
__nv_cvt_halfraw_to_fp8(
tmp, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
// fp8x2 -> half2 raw
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint32_t, DEVICE, STMTS_WRAPPER({
union {
@@ -222,6 +233,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint8_t, half, DEVICE, STMTS_WRAPPER({
return half(res);
}))
// half -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, uint8_t, DEVICE, STMTS_WRAPPER({
__half_raw tmp(val);
__nv_fp8_storage_t res =
__nv_cvt_halfraw_to_fp8(
tmp, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
// fp8x2 -> half2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
__half2_raw res =
@@ -230,6 +250,15 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, half2, DEVICE, STMTS_WRAPPER({
return half2(res);
}))
// half2 -> fp8x2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, uint16_t, DEVICE, STMTS_WRAPPER({
__half2_raw tmp(val);
__nv_fp8x2_storage_t res =
__nv_cvt_halfraw2_to_fp8x2(
tmp, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint16_t>(res);
}))
// fp8x4 -> half4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::half4, DEVICE, STMTS_WRAPPER({
@@ -242,6 +271,20 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
return res;
}))
// half4 -> fp8x4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::half4, uint32_t, DEVICE, STMTS_WRAPPER({
half2 x, y;
x = val.x;
y = val.y;
uint16_t lo, hi;
lo = CastFunctor<half2, uint16_t>()(x);
hi = CastFunctor<half2, uint16_t>()(y);
uint32_t res;
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(lo), "h"(hi));
return res;
}))
// fp8x8 -> half8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::half8, DEVICE, STMTS_WRAPPER({
@@ -314,6 +357,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
return res;
}))
// float -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(
val, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
// fp8x2 -> float2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint16_t, float2, DEVICE, STMTS_WRAPPER({
@@ -328,6 +379,28 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
return make_float2(lof, hif);
}))
// float2 -> fp8x2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
float2, uint16_t, DEVICE, STMTS_WRAPPER({
uint16_t tmp1 =
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.x));
uint16_t tmp2 =
static_cast<uint16_t>(CastFunctor<float, uint8_t>()(val.y));
uint16_t res = (tmp1 << 8U) | tmp2;
return res;
}))
// float4 -> fp8x4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float4, uint32_t, DEVICE, STMTS_WRAPPER({
uint32_t a, b, c, d;
a = CastFunctor<float, uint8_t>()(val.x);
b = CastFunctor<float, uint8_t>()(val.y);
c = CastFunctor<float, uint8_t>()(val.z);
d = CastFunctor<float, uint8_t>()(val.w);
return (a << 24U) | (b << 16U) |
(c << 8U) | d;
}))
// fp8x4 -> float4_
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, dtype::float4_, DEVICE, STMTS_WRAPPER({
@@ -338,6 +411,14 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
return res;
}))
// fp8x4 -> float4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, float4, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}))
// fp8x8 -> float8_
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint2, dtype::float8_, DEVICE, STMTS_WRAPPER({
@@ -352,16 +433,6 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
return res;
}))
// half -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(uint16_t, uint8_t, DEVICE, STMTS_WRAPPER({
__half_raw tmp;
tmp.x = val;
__nv_fp8_storage_t res =
__nv_cvt_halfraw_to_fp8(
tmp, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
// bf16 -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
STMTS_WRAPPER({
@@ -376,19 +447,24 @@ COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(__nv_bfloat16, uint8_t, DEVICE,
#endif
}))
// float -> fp8
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, uint8_t, DEVICE, STMTS_WRAPPER({
__nv_fp8_storage_t res =
__nv_cvt_float_to_fp8(
val, __NV_SATFINITE, __NV_E5M2);
return static_cast<uint8_t>(res);
}))
// fp8x4 -> float4
// bf162 -> fp8x2
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
uint32_t, float4, DEVICE, STMTS_WRAPPER({
dtype::float4_ tmp = CastFunctor<uint32_t, dtype::float4_>()(val);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
__nv_bfloat162, uint16_t, DEVICE, STMTS_WRAPPER({
uint16_t a =
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.x));
uint16_t b =
static_cast<uint16_t>(CastFunctor<__nv_bfloat16, uint8_t>()(val.y));
return (a << 8U) | b;
}))
// bf164 -> fp8x4
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(
dtype::bfloat164, uint32_t, DEVICE, STMTS_WRAPPER({
uint32_t res;
uint16_t a, b;
a = CastFunctor<__nv_bfloat162, uint16_t>()(val.x);
b = CastFunctor<__nv_bfloat162, uint16_t>()(val.y);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(a), "h"(b));
return res;
}))