[Inference] Add CUDA KVCache Kernel (#5406)

* add CUDA KVCache kernel

* comment out benchmark_kvcache_copy

* add use cuda

* fix import path

* move benchmark scripts to example/

* remove benchmark code from test_kv_cache_memcpy.py

* remove redundant code

* remove redundant code

* revise PR according to review feedback
Author: yuehuayingxueluo
Date: 2024-02-28 14:36:50 +08:00 (committed by GitHub)
Parent: 19061188c3
Commit: 600881a8ea

15 changed files with 348 additions and 75 deletions

@@ -4,6 +4,7 @@ from .flash_attention import (
FlashAttentionNpuExtension,
FlashAttentionXformersCudaExtension,
)
from .inference import InferenceOpsCudaExtension
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
@@ -15,6 +16,7 @@ ALL_EXTENSIONS = [
LayerNormCudaExtension,
MoeCudaExtension,
FusedOptimizerCudaExtension,
InferenceOpsCudaExtension,
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
@@ -28,6 +30,7 @@ __all__ = [
"LayerNormCudaExtension",
"MoeCudaExtension",
"FusedOptimizerCudaExtension",
"InferenceOpsCudaExtension",
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",

@@ -0,0 +1,15 @@
#include <torch/extension.h>

void decode_kv_cache_memcpy(
    torch::Tensor& key,               // [num_tokens, num_heads, head_size]
    torch::Tensor& value,             // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,         // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& value_cache,       // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& sequence_lengths,  // [batch_size]
    torch::Tensor& block_tables);     // [batch_size, max_seq_len]

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
        "Copy the GPU memory of kvcache during the decode stage.");
}
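Once these two files are compiled into an extension, the bound op can be called directly from Python. The snippet below is a minimal sketch of a JIT build and a single call; the source paths, the include directory for type_shim.h, and the tensor shapes are illustrative assumptions, and int32 tensors are used for sequence_lengths and block_tables because the kernel reads them as int*.

import torch
from torch.utils.cpp_extension import load

# JIT-build the op from the two new files (paths are assumptions).
inference_ops = load(
    name="inference_ops_cuda",
    sources=[
        "cuda/colossal_inference_C_frontend.cpp",
        "cuda/decode_kv_cache_memcpy_kernel.cu",
    ],
    extra_include_paths=["."],  # assumed directory containing type_shim.h
)

batch_size, num_heads, head_size = 2, 8, 64
num_blocks, block_size = 16, 16

key = torch.randn(batch_size, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, num_heads, block_size, head_size, dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
sequence_lengths = torch.tensor([5, 9], dtype=torch.int32, device="cuda")
block_tables = torch.arange(4, dtype=torch.int32, device="cuda").reshape(batch_size, 2)

# In-place copy of each sequence's newest key/value vectors into the paged caches.
inference_ops.decode_kv_cache_memcpy(key, value, key_cache, value_cache, sequence_lengths, block_tables)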

@@ -0,0 +1,90 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>
#include "type_shim.h"
template<typename scalar_t>
__global__ void decode_kv_cache_memcpy_kernel(
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache,
    scalar_t* __restrict__ value_cache,
    const int* __restrict__ sequence_lengths,
    const int* __restrict__ block_tables,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int key_stride,
    const int value_stride,
    const int block_table_stride
)
{
    // One thread block per sequence: copy the key/value vectors of the newest
    // decoded token into the paged KV cache.
    const int seq_id = blockIdx.x;
    const int seq_len = sequence_lengths[seq_id] - 1;        // index of the latest token
    const int seq_id_in_block_table = seq_len / block_size;  // logical block holding that token
    const int block_offset = seq_len % block_size;           // slot within that block
    const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
    const int hidden_size = num_heads * head_size;

    // A negative block id marks an unassigned (padded) entry; nothing to copy.
    if (block_id < 0) {
        return;
    }

    // Threads in the block cooperatively copy all hidden_size elements.
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
        const int head_id = i / head_size;
        const int head_offset = i % head_size;
        const int key_src_id = seq_id * key_stride + i;
        const int value_src_id = seq_id * value_stride + i;
        // Destination offset in the [num_blocks, num_heads, block_size, head_size] cache layout.
        const int target_src_id = block_id * hidden_size * block_size
                                  + head_id * block_size * head_size
                                  + block_offset * head_size + head_offset;

        key_cache[target_src_id] = key[key_src_id];
        value_cache[target_src_id] = value[value_src_id];
    }
}
void decode_kv_cache_memcpy(
    torch::Tensor& key,               // [num_tokens, num_heads, head_size]
    torch::Tensor& value,             // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,         // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& value_cache,       // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor& sequence_lengths,  // [batch_size]
    torch::Tensor& block_tables)      // [batch_size, max_seq_len]
{
    int num_tokens = key.size(0);
    int num_heads = key.size(1);
    int head_size = key.size(2);
    int block_size = key_cache.size(2);

    int key_stride = key.stride(0);
    int value_stride = value.stride(0);
    int block_table_stride = block_tables.stride(0);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // One block per decoded token (one per sequence in the decode stage).
    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * head_size, 512));

    DISPATCH_FLOAT_HALF_AND_BFLOAT(
        key.scalar_type(),
        "decode_kv_cache_memcpy",
        decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
            key.data_ptr<scalar_t>(),
            value.data_ptr<scalar_t>(),
            key_cache.data_ptr<scalar_t>(),
            value_cache.data_ptr<scalar_t>(),
            sequence_lengths.data_ptr<int>(),
            block_tables.data_ptr<int>(),
            num_heads,
            head_size,
            block_size,
            key_stride,
            value_stride,
            block_table_stride);)

    AT_CUDA_CHECK(cudaGetLastError());
}
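For intuition, the copy performed by this kernel corresponds to the pure-PyTorch sketch below, which could serve as a parity check in tests. The function name is made up here, and contiguous inputs are assumed so that key.stride(0) equals num_heads * head_size.

import torch

def decode_kv_cache_memcpy_ref(key, value, key_cache, value_cache, sequence_lengths, block_tables):
    # key/value: [num_tokens, num_heads, head_size]
    # key_cache/value_cache: [num_blocks, num_heads, block_size, head_size]
    block_size = key_cache.size(2)
    for seq_id in range(key.size(0)):
        seq_len = int(sequence_lengths[seq_id]) - 1                  # position of the newest token
        block_id = int(block_tables[seq_id, seq_len // block_size])  # physical cache block
        if block_id < 0:                                             # unassigned / padded slot
            continue
        block_offset = seq_len % block_size
        key_cache[block_id, :, block_offset, :] = key[seq_id]
        value_cache[block_id, :, block_offset, :] = value[seq_id]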

@@ -24,6 +24,27 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...)                  \
  switch (TYPE) {                                                        \
    case at::ScalarType::Float: {                                        \
      using scalar_t = float;                                            \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    case at::ScalarType::Half: {                                         \
      using scalar_t = at::Half;                                         \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    case at::ScalarType::BFloat16: {                                     \
      using scalar_t = at::BFloat16;                                     \
      __VA_ARGS__;                                                       \
      break;                                                             \
    }                                                                    \
    default:                                                             \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");    \
  }
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \

@@ -1,7 +1,10 @@
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List
from .base_extension import _Extension
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list

@@ -0,0 +1,3 @@
from .inference_ops_cuda import InferenceOpsCudaExtension

__all__ = ["InferenceOpsCudaExtension"]

@@ -0,0 +1,30 @@
from ..cuda_extension import _CudaExtension
from ..utils import get_cuda_cc_flag


class InferenceOpsCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="inference_ops_cuda")

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "cuda/colossal_inference_C_frontend.cpp",
                "cuda/decode_kv_cache_memcpy_kernel.cu",
            ]
        ]
        return ret

    def include_dirs(self):
        ret = [self.get_cuda_home_include()]
        return ret

    def cxx_flags(self):
        version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
        return ["-O3"] + version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ["-lineinfo"]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        return ["-O3", "--use_fast_math"] + extra_cuda_flags
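ColossalAI's extension machinery builds this module from the class above. Purely for orientation, the same sources and flags would map onto a plain torch.utils.cpp_extension build roughly as in the sketch below; the source paths and the include directory for type_shim.h are assumptions, not taken from this diff.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="inference_ops_cuda",
    ext_modules=[
        CUDAExtension(
            name="inference_ops_cuda",
            sources=[
                "cuda/colossal_inference_C_frontend.cpp",
                "cuda/decode_kv_cache_memcpy_kernel.cu",
            ],
            include_dirs=["."],  # assumed directory containing type_shim.h
            extra_compile_args={
                "cxx": ["-O3", "-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"],
                "nvcc": ["-O3", "--use_fast_math", "-lineinfo"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

The real extension additionally appends compute-capability flags from get_cuda_cc_flag() to the nvcc options, as shown in nvcc_flags() above.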