[Inference/Feat] Add kvcache quantization support for FlashDecoding (#5656)

This commit is contained in:
傅剑寒
2024-04-26 19:40:37 +08:00
committed by GitHub
parent 5be590b99e
commit 8ccb6714e7
5 changed files with 482 additions and 174 deletions

View File

@@ -5,6 +5,7 @@
#include <cuda_fp16.h>
#endif
#include <ATen/ATen.h>
#include <stdint.h>
#include "common/data_type.h"
@@ -27,6 +28,7 @@ struct FloatVecTypeTrait {};
VEC_TYPE_TRAITS_SPECIALIZATION(T, 1, T, typename T)
#if defined(COLOSSAL_WITH_CUDA)
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 1, __nv_bfloat16)
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 2, __nv_bfloat162)
VEC_TYPE_TRAITS_SPECIALIZATION(at::BFloat16, 4, float2)
@@ -35,18 +37,19 @@ VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 1, half)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 2, half2)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 4, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(at::Half, 8, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, half)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, half2)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 2, uint16_t)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 4, uint32_t)
VEC_TYPE_TRAITS_SPECIALIZATION(uint8_t, 8, uint2)
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 2, __nv_bfloat162);
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 4, dtype::bfloat164);
VEC_TYPE_TRAITS_SPECIALIZATION(__nv_bfloat16, 8, dtype::bfloat168);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 2, half2);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 4, dtype::half4);
VEC_TYPE_TRAITS_SPECIALIZATION(half, 8, dtype::half8);
VEC_TYPE_TRAITS_SPECIALIZATION(float, 2, float2)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 4, float4)
VEC_TYPE_TRAITS_SPECIALIZATION(float, 8, dtype::float8_)
#endif /* defined(COLOSSAL_WITH_CUDA) */
#undef VEC_TYPE_TRAITS_SPECIALIZATION