[Inference/Feat] Add kvcache quantization support for FlashDecoding (#5656)

傅剑寒
2024-04-26 19:40:37 +08:00
committed by GitHub
parent 5be590b99e
commit 8ccb6714e7
5 changed files with 482 additions and 174 deletions
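
In short: the KV cache can now be stored as uint8 (at::ScalarType::Byte) holding FP8-encoded values, and the FlashDecoding kernel widens those bytes back to the query's compute dtype on every load via CastFunctor, instead of requiring cache and query to share a dtype. A minimal sketch of the byte-to-half widening, assuming CUDA's e5m2 encoding and no per-tensor scale (neither is pinned down by this excerpt; the real conversions live in funcs/cast_functor.h):

#include <cuda_fp8.h>
#include <cuda_fp16.h>

// Widen one FP8-encoded cache byte to half precision.
// The e5m2 interpretation is an assumption for illustration.
__device__ __half dequant_fp8_byte(uint8_t byte) {
  __half_raw raw =
      __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(byte), __NV_E5M2);
  return __half(raw);
}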


@@ -174,13 +174,13 @@ void context_kv_cache_memcpy(
     key.scalar_type(),
     "context_kv_cache_memcpy",
     apply_context_kv_cache_memcpy<scalar_t>(
-      key,
-      value,
-      key_cache,
-      value_cache,
-      sequence_lengths,
-      cu_seqlens,
-      block_tables,
-      max_seq_len_in_batch
-    );)
+      key,
+      value,
+      key_cache,
+      value_cache,
+      sequence_lengths,
+      cu_seqlens,
+      block_tables,
+      max_seq_len_in_batch
+    );)
 }
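
The hunk above only touches the argument list of a dispatch-macro invocation, but the surrounding pattern is worth spelling out: DISPATCH_FLOAT_HALF_AND_BFLOAT (from common/micros.h, not shown in this excerpt) expands to a switch on the runtime scalar type that instantiates a templated launcher. A sketch of the same shape using PyTorch's stock macro, with illustrative names:

#include <torch/extension.h>

// Templated launcher standing in for apply_context_kv_cache_memcpy<scalar_t>.
template <typename scalar_t>
void apply_memcpy_sketch(const torch::Tensor& key, torch::Tensor& key_cache) {
  // ... launch a kernel instantiated for scalar_t ...
}

void memcpy_dispatch_sketch(const torch::Tensor& key, torch::Tensor& key_cache) {
  // Dispatch on the runtime dtype; scalar_t is bound inside the lambda.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, key.scalar_type(),
      "context_kv_cache_memcpy",
      [&] { apply_memcpy_sketch<scalar_t>(key, key_cache); });
}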


@@ -5,7 +5,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <stdio.h>
 #include "common/micros.h"
 #include "funcs/cast_functor.h"
@@ -34,11 +33,25 @@ constexpr unsigned int nextHighestPowerOf2(unsigned int v) {
   return v;
 }
 
+template <typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ii++) {
+    tmp.words[ii] = 0u;
+  }
+
+  dst = tmp.raw;
+}
+
 using colossalAI::funcs::BinaryOpType;
 using colossalAI::funcs::CastFunctor;
 using colossalAI::funcs::TernaryOpFunctor;
 using colossalAI::funcs::TernaryOpType;
-using colossalAI::funcs::zero;
 using colossalAI::common::VecTypeTrait;
 using colossalAI::common::FloatVecTypeTrait;
 using namespace colossalAI::cuda::attention;
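
The zero() helper added above clears any trivially-copyable vector type by aliasing it with 32-bit words; the unrolled loop compiles down to plain register zeroing. It assumes sizeof(T) is a multiple of 4, which holds for the CUDA vector types used here. A usage sketch (relies on the zero() template defined in the hunk above):

__global__ void zero_demo(float4* out) {
  float4 acc;
  zero(acc);  // after unrolling: acc = {0.f, 0.f, 0.f, 0.f}
  out[threadIdx.x] = acc;
}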
@@ -84,10 +97,12 @@ __global__ void flash_decoding_attention_kernel(
   constexpr int NUM_ROWS_PER_ROUNDS = MIN(WARP_SIZE / NUM_THREADS_PER_X, BLOCK_SIZE);
   constexpr int NUM_VECS_PER_THREAD = NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN / WARP_SIZE;
 
-  using K_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
-  using V_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
-  using L_vec = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
-  using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;
+  using KVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
+  using VVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
+  using KQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
+  using VQuantVecT = typename VecTypeTrait<cache_t, VEC_SIZE>::Type;
+  using LVecT = typename VecTypeTrait<scalar_t, VEC_SIZE>::Type;
+  using FloatVecT = typename FloatVecTypeTrait<LVecT>::Type;
 
   const int context_len = context_lens[seq_idx];
   const int thread_group_offset = lane % NUM_THREADS_PER_X;
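
The renamed aliases make the two type families explicit: KVecT/VVecT are vectors of the compute dtype (scalar_t), while KQuantVecT/VQuantVecT are same-width vectors of the storage dtype (cache_t), so one reinterpret_cast plus one cast recovers compute-precision data from the quantized cache. A hedged sketch of what a VecTypeTrait-style mapping looks like (the real trait lives in ColossalAI's common headers, not in this diff; half4 is an illustrative carrier, not a CUDA builtin):

#include <cuda_fp16.h>
#include <cstdint>

struct half4 { half2 x, y; };  // 4 halves packed as two half2 registers

template <typename T, int VecSize> struct VecTypeTraitSketch;
template <> struct VecTypeTraitSketch<float, 4>   { using Type = float4; };  // 4 x 32-bit
template <> struct VecTypeTraitSketch<half, 4>    { using Type = half4;  };  // 4 x 16-bit
template <> struct VecTypeTraitSketch<uint8_t, 4> { using Type = uchar4; };  // 4 x 8-bit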
@@ -119,18 +134,18 @@ __global__ void flash_decoding_attention_kernel(
   scalar_t* q_shared_ptr = reinterpret_cast<scalar_t*>(q_shared);
 
   // each warp accesses a whole block
-  K_vec q_vecs[NUM_VECS_PER_THREAD];
+  KVecT q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
   for (int idx = lane, i = 0; idx < NUM_ROWS_PER_ROUNDS * NUM_VECS_PER_TOKEN; idx += WARP_SIZE, i += 1) {
     const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
     const int offset1 = idx % NUM_THREADS_PER_X;
-    q_vecs[i] = *reinterpret_cast<K_vec*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
+    q_vecs[i] = *reinterpret_cast<KVecT*>(q_shared_ptr + offset0 * x + offset1 * VEC_SIZE);
   }
 
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
     const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
-    K_vec k_vecs[NUM_VECS_PER_THREAD];
+    KVecT k_vecs[NUM_VECS_PER_THREAD];
 
 #pragma unroll
     for (int i = 0; i < BLOCK_SIZE; i += NUM_ROWS_PER_ROUNDS) {
@@ -142,7 +157,7 @@ __global__ void flash_decoding_attention_kernel(
       const int offset0 = idx / NUM_THREADS_PER_X / NUM_ROWS_PER_ROUNDS;
       const int offset1 = (idx / NUM_THREADS_PER_X) % NUM_ROWS_PER_ROUNDS;
       const int offset2 = idx % NUM_THREADS_PER_X;
-      k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE);
+      k_vecs[j] = CastFunctor<KQuantVecT, KVecT>()(*reinterpret_cast<const KQuantVecT*>(k_ptr + offset0 * BLOCK_SIZE * x + offset1 * x + offset2 * VEC_SIZE));
     }
 
     float qk = scale * Qk_dot<scalar_t, NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X, NUM_THREADS_PER_X>::dot(q_vecs, k_vecs);
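
With the change above, the K path loads a KQuantVecT straight from the cache and widens it through CastFunctor before the Q·K dot product; nothing downstream of k_vecs changes. A hedged sketch of the specialization shape such a cast needs (the functor name, the e5m2 format, and the absence of a scale factor are all assumptions here; the real specializations live in funcs/cast_functor.h):

#include <cuda_fp8.h>
#include <cuda_fp16.h>

template <typename From, typename To> struct CastFunctorSketch;

// uchar4 of FP8 bytes -> float4 of widened values.
template <> struct CastFunctorSketch<uchar4, float4> {
  __device__ float4 operator()(uchar4 q) const {
    auto widen = [](uint8_t b) {
      return __half2float(__half(__nv_cvt_fp8_to_halfraw(b, __NV_E5M2)));
    };
    float4 f;
    f.x = widen(q.x);
    f.y = widen(q.y);
    f.z = widen(q.z);
    f.w = widen(q.w);
    return f;
  }
};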
@@ -174,13 +189,13 @@ __global__ void flash_decoding_attention_kernel(
   }
   __syncthreads();
 
-  Float_vec accs[NUM_ROUNDS_PER_TOKEN];
+  FloatVecT accs[NUM_ROUNDS_PER_TOKEN];
 #pragma unroll
   for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
     zero(accs[i]);
   }
 
-  V_vec zero_value;
+  VVecT zero_value;
   zero(zero_value);
 
   for (int block_idx = warp_idx; block_idx < num_context_blocks; block_idx += NUM_WARPS) {
     const int64_t physical_block_number = static_cast<int64_t>(block_table_shared[block_idx]);
@@ -193,11 +208,11 @@ __global__ void flash_decoding_attention_kernel(
                              + kv_head_idx * kv_head_stride
                              + idx * VEC_SIZE;
 
-    V_vec v_vecs[NUM_ROUNDS_PER_TOKEN];
+    VVecT v_vecs[NUM_ROUNDS_PER_TOKEN];
 #pragma unroll
     for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-      v_vecs[i] = (reinterpret_cast<const V_vec*>(v_ptr))[i * WARP_SIZE];
+      v_vecs[i] = CastFunctor<VQuantVecT, VVecT>()(*(reinterpret_cast<const VQuantVecT*>(v_ptr) + i * WARP_SIZE));
     }
 
     if (token_idx >= context_len) {
@@ -210,7 +225,7 @@ __global__ void flash_decoding_attention_kernel(
       logit = CastFunctor<float, scalar_t>()(logits[token_idx]);
 #pragma unroll
       for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-        accs[i] = TernaryOpFunctor<scalar_t, V_vec, Float_vec, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
+        accs[i] = TernaryOpFunctor<scalar_t, VVecT, FloatVecT, TernaryOpType::kFma>()(logit, v_vecs[i], accs[i]);
       }
     }
   }
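
The V path mirrors the K path: each dequantized value vector is scaled by its softmax weight (logit) and accumulated into float registers with a fused multiply-add, which is what TernaryOpFunctor<..., kFma> does elementwise. A minimal sketch of that accumulate step (names illustrative):

__device__ void fma_accumulate(float logit, float4 v, float4& acc) {
  acc.x = fmaf(logit, v.x, acc.x);  // acc += logit * v, elementwise
  acc.y = fmaf(logit, v.y, acc.y);
  acc.z = fmaf(logit, v.z, acc.z);
  acc.w = fmaf(logit, v.w, acc.w);
}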
@@ -220,16 +235,16 @@ __global__ void flash_decoding_attention_kernel(
 #pragma unroll
   for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
-    block_sum<Float_vec, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
+    block_sum<FloatVecT, NUM_WARPS, NUM_THREADS_PER_TOKEN, VEC_SIZE>(out_shared_mem, accs[i]);
   }
 
   scalar_t* out_ptr = out + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  L_vec out_reg;
+  LVecT out_reg;
 #pragma unroll
   for (int i = 0; i < NUM_ROUNDS_PER_TOKEN; i++) {
     if (thread_idx < NUM_THREADS_PER_TOKEN) {
-      out_reg = CastFunctor<Float_vec, L_vec>()(accs[i]);
-      (reinterpret_cast<L_vec*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
+      out_reg = CastFunctor<FloatVecT, LVecT>()(accs[i]);
+      (reinterpret_cast<LVecT*>(out_ptr))[thread_idx + i * NUM_THREADS_PER_TOKEN] = out_reg;
     }
   }
 }
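
Writeback is unchanged apart from the renames: the float accumulators are reduced across the block, narrowed back to the output dtype, and stored one vector per thread. A sketch of the narrow-and-store step, reusing the illustrative half4 carrier from the trait sketch above (not a CUDA builtin):

__device__ void store_output(half4* out_ptr, int slot, float4 acc) {
  half4 r;
  r.x = __floats2half2_rn(acc.x, acc.y);  // round-to-nearest narrowing
  r.y = __floats2half2_rn(acc.z, acc.w);
  out_ptr[slot] = r;
}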
@@ -353,18 +368,40 @@ void flash_decoding_attention(
   torch::Tensor& tmp_out,      // [num_tokens, num_heads, max_num_partitions, head_size]
   torch::Tensor& tmp_out_lse,  // [num_tokens, num_heads, max_num_partitions]
   float scale) {
-  switch (query.scalar_type()) {
-    case at::ScalarType::Float:
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
-      break;
-    case at::ScalarType::Half:
-      CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
-      break;
-    case at::ScalarType::BFloat16:
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
-      break;
-    default:
-      AT_ERROR("Unsupported data type: ", toString(query.scalar_type()));
-  }
+  TORCH_CHECK(query.scalar_type() == at::ScalarType::Float ||
+              query.scalar_type() == at::ScalarType::Half ||
+              query.scalar_type() == at::ScalarType::BFloat16,
+              "Dtype of query should be float, half or bfloat16!");
+  TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte ||
+              key_cache.scalar_type() == query.scalar_type(),
+              "Dtype of query and kvcache should be the same unless dtype of kvcache is fp8!");
+
+  if (key_cache.scalar_type() == at::ScalarType::Byte) {
+    switch (query.scalar_type()) {
+      case at::ScalarType::Float:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t);
+        break;
+      case at::ScalarType::Half:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(half, uint8_t);
+        break;
+      case at::ScalarType::BFloat16:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t);
+        break;
+    }
+  } else {
+    switch (query.scalar_type()) {
+      case at::ScalarType::Float:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(float, float);
+        break;
+      case at::ScalarType::Half:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(half, half);
+        break;
+      case at::ScalarType::BFloat16:
+        CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16);
+        break;
+    }
+  }
 }
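
For reference, the two nested switches above collapse into a single dispatch once the cache dtype is folded into a boolean. A sketch of the same control flow (launch<> is an illustrative stand-in for CALL_V1_LAUNCHER_BLOCK_SIZE, whose definition is outside this excerpt):

#include <torch/extension.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Stand-in for CALL_V1_LAUNCHER_BLOCK_SIZE(scalar_t, cache_t).
template <typename scalar_t, typename cache_t>
void launch() { /* instantiate and launch the kernel */ }

void dispatch_sketch(const at::Tensor& query, const at::Tensor& key_cache) {
  const bool fp8_cache = key_cache.scalar_type() == at::ScalarType::Byte;
  switch (query.scalar_type()) {
    case at::ScalarType::Float:
      fp8_cache ? launch<float, uint8_t>() : launch<float, float>();
      break;
    case at::ScalarType::Half:
      fp8_cache ? launch<half, uint8_t>() : launch<half, half>();
      break;
    case at::ScalarType::BFloat16:
      fp8_cache ? launch<__nv_bfloat16, uint8_t>() : launch<__nv_bfloat16, __nv_bfloat16>();
      break;
    default:
      TORCH_CHECK(false, "Unsupported query dtype");
  }
}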