mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-29 08:47:17 +00:00
[Inference]Add CUDA KVCache Kernel (#5406)
* add cuda KVCache kernel * annotation benchmark_kvcache_copy * add use cuda * fix import path * move benchmark scripts to example/ * rm benchmark codes in test_kv_cache_memcpy.py * rm redundancy codes * rm redundancy codes * pr was modified according to the review
This commit is contained in:
parent
19061188c3
commit
600881a8ea
@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
|
|
||||||
from colossalai.inference.batch_bucket import BatchBucket
|
from colossalai.inference.batch_bucket import BatchBucket
|
||||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
from colossalai.kernel.triton import (
|
from colossalai.kernel.triton import (
|
||||||
context_attention_unpadded,
|
context_attention_unpadded,
|
||||||
decoding_fused_rotary_embedding,
|
decoding_fused_rotary_embedding,
|
||||||
@ -22,6 +23,8 @@ from colossalai.kernel.triton import (
|
|||||||
)
|
)
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
logger = get_dist_logger(__name__)
|
logger = get_dist_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -74,6 +77,12 @@ def llama_model_forward(
|
|||||||
sequence_lengths = batch.get_sequence_lengths()
|
sequence_lengths = batch.get_sequence_lengths()
|
||||||
batch_size = batch.current_batch_size
|
batch_size = batch.current_batch_size
|
||||||
kv_seq_len = sequence_lengths.max().item()
|
kv_seq_len = sequence_lengths.max().item()
|
||||||
|
use_cuda_kernel = True
|
||||||
|
# NOTE: After testing, the performance of this configuration is relatively good. With updates
|
||||||
|
# and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
|
||||||
|
# selection should be conducted.
|
||||||
|
if batch_size >= 32 and kv_seq_len > 512:
|
||||||
|
use_cuda_kernel = False
|
||||||
|
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
@ -107,6 +116,7 @@ def llama_model_forward(
|
|||||||
output_tensor=output_tensor,
|
output_tensor=output_tensor,
|
||||||
norm_output=norm_output,
|
norm_output=norm_output,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
|
use_cuda_kernel=use_cuda_kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch.is_prompts:
|
if batch.is_prompts:
|
||||||
@ -134,6 +144,7 @@ def llama_decoder_layer_forward(
|
|||||||
output_tensor: torch.Tensor = None,
|
output_tensor: torch.Tensor = None,
|
||||||
norm_output: torch.Tensor = None,
|
norm_output: torch.Tensor = None,
|
||||||
sm_scale: int = None,
|
sm_scale: int = None,
|
||||||
|
use_cuda_kernel: bool = True,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""This function will replace the forward function of LlamaDecoderLayer.
|
"""This function will replace the forward function of LlamaDecoderLayer.
|
||||||
|
|
||||||
@ -153,6 +164,7 @@ def llama_decoder_layer_forward(
|
|||||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
||||||
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
|
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
|
||||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||||
|
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
|
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
|
||||||
@ -169,6 +181,7 @@ def llama_decoder_layer_forward(
|
|||||||
fd_inter_tensor=fd_inter_tensor,
|
fd_inter_tensor=fd_inter_tensor,
|
||||||
output_tensor=output_tensor,
|
output_tensor=output_tensor,
|
||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
|
use_cuda_kernel=use_cuda_kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
@ -252,6 +265,7 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
fd_inter_tensor: FDIntermTensors = None,
|
fd_inter_tensor: FDIntermTensors = None,
|
||||||
output_tensor: torch.Tensor = None,
|
output_tensor: torch.Tensor = None,
|
||||||
sm_scale: int = None,
|
sm_scale: int = None,
|
||||||
|
use_cuda_kernel: bool = True,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -268,6 +282,7 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
storing intermediate values in flash-decoding. Defaults to None.
|
storing intermediate values in flash-decoding. Defaults to None.
|
||||||
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
|
||||||
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
sm_scale (int, optional): Used for flash attention. Defaults to None.
|
||||||
|
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.num_heads != self.num_key_value_heads:
|
if self.num_heads != self.num_key_value_heads:
|
||||||
@ -283,7 +298,6 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
)
|
)
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = context_attention_unpadded(
|
||||||
@ -300,17 +314,23 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
sm_scale=sm_scale,
|
sm_scale=sm_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decoding_fused_rotary_embedding(
|
if use_cuda_kernel:
|
||||||
query_states,
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
key_states,
|
inference_ops.decode_kv_cache_memcpy(
|
||||||
value_states,
|
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
|
||||||
cos_sin[0],
|
)
|
||||||
cos_sin[1],
|
else:
|
||||||
k_cache,
|
decoding_fused_rotary_embedding(
|
||||||
v_cache,
|
query_states,
|
||||||
block_tables,
|
key_states,
|
||||||
sequence_lengths,
|
value_states,
|
||||||
)
|
cos_sin[0],
|
||||||
|
cos_sin[1],
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
block_tables,
|
||||||
|
sequence_lengths,
|
||||||
|
)
|
||||||
attn_output = flash_decoding_attention(
|
attn_output = flash_decoding_attention(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
k_cache=k_cache,
|
k_cache=k_cache,
|
||||||
|
@ -8,6 +8,7 @@ from .extensions import (
|
|||||||
FlashAttentionNpuExtension,
|
FlashAttentionNpuExtension,
|
||||||
FlashAttentionXformersCudaExtension,
|
FlashAttentionXformersCudaExtension,
|
||||||
FusedOptimizerCudaExtension,
|
FusedOptimizerCudaExtension,
|
||||||
|
InferenceOpsCudaExtension,
|
||||||
LayerNormCudaExtension,
|
LayerNormCudaExtension,
|
||||||
MoeCudaExtension,
|
MoeCudaExtension,
|
||||||
ScaledMaskedSoftmaxCudaExtension,
|
ScaledMaskedSoftmaxCudaExtension,
|
||||||
@ -21,6 +22,7 @@ __all__ = [
|
|||||||
"LayerNormLoader",
|
"LayerNormLoader",
|
||||||
"MoeLoader",
|
"MoeLoader",
|
||||||
"FusedOptimizerLoader",
|
"FusedOptimizerLoader",
|
||||||
|
"InferenceOpsLoader",
|
||||||
"ScaledMaskedSoftmaxLoader",
|
"ScaledMaskedSoftmaxLoader",
|
||||||
"ScaledUpperTriangleMaskedSoftmaxLoader",
|
"ScaledUpperTriangleMaskedSoftmaxLoader",
|
||||||
]
|
]
|
||||||
@ -97,6 +99,10 @@ class FusedOptimizerLoader(KernelLoader):
|
|||||||
REGISTRY = [FusedOptimizerCudaExtension]
|
REGISTRY = [FusedOptimizerCudaExtension]
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceOpsLoader(KernelLoader):
|
||||||
|
REGISTRY = [InferenceOpsCudaExtension]
|
||||||
|
|
||||||
|
|
||||||
class ScaledMaskedSoftmaxLoader(KernelLoader):
|
class ScaledMaskedSoftmaxLoader(KernelLoader):
|
||||||
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
|
REGISTRY = [ScaledMaskedSoftmaxCudaExtension]
|
||||||
|
|
||||||
|
@ -0,0 +1,80 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.inference.modeling.layers.attention import copy_to_cache
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
HEAD_DIM = 4
|
||||||
|
BATCH = 16
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
SAME_LEN = True
|
||||||
|
WARM_UPS = 10
|
||||||
|
REPS = 100
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["KV_SEQ_LEN"],
|
||||||
|
x_vals=[2**i for i in range(8, 13)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
|
||||||
|
line_names=["torch_copy_func", "triton_copy_func", "cuda_copy_func"],
|
||||||
|
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
|
||||||
|
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_kvcache_copy(
|
||||||
|
provider: str,
|
||||||
|
bsz: int,
|
||||||
|
block_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
|
||||||
|
num_kv_heads: int,
|
||||||
|
same_context_len: bool,
|
||||||
|
):
|
||||||
|
dtype = torch.float32
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
||||||
|
|
||||||
|
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
|
||||||
|
bsz,
|
||||||
|
num_kv_heads,
|
||||||
|
HEAD_DIM,
|
||||||
|
block_size,
|
||||||
|
max_seq_len // block_size,
|
||||||
|
same_context_len,
|
||||||
|
KV_SEQ_LEN,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
||||||
|
if provider == "torch_copy_func":
|
||||||
|
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
||||||
|
elif provider == "triton_copy_func":
|
||||||
|
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||||
|
elif provider == "cuda_copy_func":
|
||||||
|
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
|
||||||
|
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
|
||||||
|
fn = lambda: inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_kvcache_copy.run(save_path=".", print_data=True)
|
@ -4,6 +4,7 @@ from .flash_attention import (
|
|||||||
FlashAttentionNpuExtension,
|
FlashAttentionNpuExtension,
|
||||||
FlashAttentionXformersCudaExtension,
|
FlashAttentionXformersCudaExtension,
|
||||||
)
|
)
|
||||||
|
from .inference import InferenceOpsCudaExtension
|
||||||
from .layernorm import LayerNormCudaExtension
|
from .layernorm import LayerNormCudaExtension
|
||||||
from .moe import MoeCudaExtension
|
from .moe import MoeCudaExtension
|
||||||
from .optimizer import FusedOptimizerCudaExtension
|
from .optimizer import FusedOptimizerCudaExtension
|
||||||
@ -15,6 +16,7 @@ ALL_EXTENSIONS = [
|
|||||||
LayerNormCudaExtension,
|
LayerNormCudaExtension,
|
||||||
MoeCudaExtension,
|
MoeCudaExtension,
|
||||||
FusedOptimizerCudaExtension,
|
FusedOptimizerCudaExtension,
|
||||||
|
InferenceOpsCudaExtension,
|
||||||
ScaledMaskedSoftmaxCudaExtension,
|
ScaledMaskedSoftmaxCudaExtension,
|
||||||
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
|
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
|
||||||
FlashAttentionDaoCudaExtension,
|
FlashAttentionDaoCudaExtension,
|
||||||
@ -28,6 +30,7 @@ __all__ = [
|
|||||||
"LayerNormCudaExtension",
|
"LayerNormCudaExtension",
|
||||||
"MoeCudaExtension",
|
"MoeCudaExtension",
|
||||||
"FusedOptimizerCudaExtension",
|
"FusedOptimizerCudaExtension",
|
||||||
|
"InferenceOpsCudaExtension",
|
||||||
"ScaledMaskedSoftmaxCudaExtension",
|
"ScaledMaskedSoftmaxCudaExtension",
|
||||||
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
|
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
|
||||||
"FlashAttentionDaoCudaExtension",
|
"FlashAttentionDaoCudaExtension",
|
||||||
|
15
extensions/csrc/cuda/colossal_inference_C_frontend.cpp
Normal file
15
extensions/csrc/cuda/colossal_inference_C_frontend.cpp
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void decode_kv_cache_memcpy(
|
||||||
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||||
|
torch::Tensor&
|
||||||
|
value_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||||
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||||
|
"Copy the GPU memory of kvcache during the decode stage.");
|
||||||
|
}
|
90
extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
Normal file
90
extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "type_shim.h"
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void decode_kv_cache_memcpy_kernel(
|
||||||
|
const scalar_t* __restrict__ key,
|
||||||
|
const scalar_t* __restrict__ value,
|
||||||
|
scalar_t* __restrict__ key_cache,
|
||||||
|
scalar_t* __restrict__ value_cache,
|
||||||
|
const int* __restrict__ sequence_lengths,
|
||||||
|
const int* __restrict__ block_tables,
|
||||||
|
const int num_heads,
|
||||||
|
const int head_size,
|
||||||
|
const int block_size,
|
||||||
|
const int key_stride,
|
||||||
|
const int value_stride,
|
||||||
|
const int block_table_stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const int seq_id = blockIdx.x;
|
||||||
|
const int seq_len = sequence_lengths[seq_id] - 1;
|
||||||
|
const int seq_id_in_block_table = seq_len / block_size;
|
||||||
|
const int block_offset = seq_len % block_size;
|
||||||
|
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
|
||||||
|
const int hidden_size = num_heads * head_size;
|
||||||
|
|
||||||
|
if ( block_id < 0 ) {
|
||||||
|
return ;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||||
|
const int head_id = i / head_size;
|
||||||
|
const int head_offset = i % head_size;
|
||||||
|
const int key_src_id = seq_id * key_stride + i;
|
||||||
|
const int value_src_id = seq_id * value_stride + i;
|
||||||
|
const int target_src_id = block_id * hidden_size * block_size
|
||||||
|
+ head_id * block_size * head_size
|
||||||
|
+ block_offset * head_size + head_offset;
|
||||||
|
|
||||||
|
key_cache[target_src_id] = key[key_src_id];
|
||||||
|
value_cache[target_src_id] = value[value_src_id];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void decode_kv_cache_memcpy(
|
||||||
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||||
|
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||||
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
torch::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
int num_tokens = key.size(0);
|
||||||
|
int num_heads = key.size(1);
|
||||||
|
int head_size = key.size(2);
|
||||||
|
int block_size = key_cache.size(2);
|
||||||
|
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
int value_stride = value.stride(0);
|
||||||
|
int block_table_stride = block_tables.stride(0);
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
key.scalar_type(),
|
||||||
|
"decode_kv_cache_memcpy",
|
||||||
|
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
block_size,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
block_table_stride
|
||||||
|
);)
|
||||||
|
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
}
|
@ -24,6 +24,27 @@
|
|||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||||
|
switch (TYPE) { \
|
||||||
|
case at::ScalarType::Float: { \
|
||||||
|
using scalar_t = float; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: { \
|
||||||
|
using scalar_t = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: { \
|
||||||
|
using scalar_t = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
}
|
||||||
|
|
||||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||||
switch (TYPEIN) { \
|
switch (TYPEIN) { \
|
||||||
case at::ScalarType::Float: { \
|
case at::ScalarType::Float: { \
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from .base_extension import _Extension
|
||||||
from .cpp_extension import _CppExtension
|
from .cpp_extension import _CppExtension
|
||||||
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
|
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list
|
||||||
|
|
||||||
|
3
extensions/inference/__init__.py
Normal file
3
extensions/inference/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .inference_ops_cuda import InferenceOpsCudaExtension
|
||||||
|
|
||||||
|
__all__ = ["InferenceOpsCudaExtension"]
|
30
extensions/inference/inference_ops_cuda.py
Normal file
30
extensions/inference/inference_ops_cuda.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from ..cuda_extension import _CudaExtension
|
||||||
|
from ..utils import get_cuda_cc_flag
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceOpsCudaExtension(_CudaExtension):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(name="inference_ops_cuda")
|
||||||
|
|
||||||
|
def sources_files(self):
|
||||||
|
ret = [
|
||||||
|
self.csrc_abs_path(fname)
|
||||||
|
for fname in [
|
||||||
|
"cuda/colossal_inference_C_frontend.cpp",
|
||||||
|
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def include_dirs(self):
|
||||||
|
ret = [self.get_cuda_home_include()]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def cxx_flags(self):
|
||||||
|
version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
|
||||||
|
return ["-O3"] + version_dependent_macros
|
||||||
|
|
||||||
|
def nvcc_flags(self):
|
||||||
|
extra_cuda_flags = ["-lineinfo"]
|
||||||
|
extra_cuda_flags.extend(get_cuda_cc_flag())
|
||||||
|
return ["-O3", "--use_fast_math"] + extra_cuda_flags
|
0
tests/test_infer/test_ops/__init__.py
Normal file
0
tests/test_infer/test_ops/__init__.py
Normal file
0
tests/test_infer/test_ops/cuda/__init__.py
Normal file
0
tests/test_infer/test_ops/cuda/__init__.py
Normal file
65
tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py
Normal file
65
tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
HEAD_DIM = 4
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bsz", [4, 7, 32])
|
||||||
|
@pytest.mark.parametrize("block_size", [16, 32, 64])
|
||||||
|
@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
|
||||||
|
@pytest.mark.parametrize("num_kv_heads", [16])
|
||||||
|
@pytest.mark.parametrize("same_context_len", [True, False])
|
||||||
|
def test_copy_kv_to_caches(
|
||||||
|
bsz: int,
|
||||||
|
block_size: int,
|
||||||
|
max_num_blocks_per_seq: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
same_context_len: bool,
|
||||||
|
):
|
||||||
|
torch.manual_seed(123)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
max_seq_len = block_size * max_num_blocks_per_seq
|
||||||
|
dtype = torch.float32
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables = prepare_data(
|
||||||
|
bsz,
|
||||||
|
num_kv_heads,
|
||||||
|
HEAD_DIM,
|
||||||
|
block_size,
|
||||||
|
max_num_blocks_per_seq,
|
||||||
|
same_context_len,
|
||||||
|
max_seq_len,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_k = new_k.squeeze(1) if new_k.dim() == 4 else new_k
|
||||||
|
new_v = new_v.squeeze(1) if new_v.dim() == 4 else new_v
|
||||||
|
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables)
|
||||||
|
|
||||||
|
past_kv_seq_len = kv_seq_lengths - 1
|
||||||
|
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||||
|
offsets_in_block = past_kv_seq_len % block_size
|
||||||
|
k_target = k_cache[target_block_ids, :, offsets_in_block, :]
|
||||||
|
k_source = new_k.squeeze()
|
||||||
|
v_target = v_cache[target_block_ids, :, offsets_in_block, :]
|
||||||
|
v_source = new_v.squeeze()
|
||||||
|
|
||||||
|
assert k_target.shape == k_source.shape
|
||||||
|
assert torch.equal(k_target, k_source)
|
||||||
|
assert v_target.shape == v_source.shape
|
||||||
|
assert torch.equal(v_target, v_source)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
0
tests/test_infer/test_ops/triton/__init__.py
Normal file
0
tests/test_infer/test_ops/triton/__init__.py
Normal file
@ -2,7 +2,6 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from colossalai.inference.modeling.layers.attention import copy_to_cache
|
|
||||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, mock_alloc_single_token
|
||||||
@ -108,69 +107,7 @@ def test_copy_kv_to_caches(
|
|||||||
assert torch.equal(k_target, k_source)
|
assert torch.equal(k_target, k_source)
|
||||||
assert v_target.shape == v_source.shape
|
assert v_target.shape == v_source.shape
|
||||||
assert torch.equal(v_target, v_source)
|
assert torch.equal(v_target, v_source)
|
||||||
# target_torch = k_cache_copy[target_block_ids, :, offsets_in_block, :]
|
|
||||||
# assert target_torch.shape == source.shape
|
|
||||||
# assert torch.equal(target_torch, source)
|
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
|
||||||
BLOCK_SIZE = 32
|
|
||||||
SAME_LEN = True
|
|
||||||
WARM_UPS = 10
|
|
||||||
REPS = 100
|
|
||||||
configs = [
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["KV_SEQ_LEN"],
|
|
||||||
x_vals=[2**i for i in range(8, 13)],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=["torch_copy_func", "triton_copy_func"],
|
|
||||||
line_names=["torch_copy_func", "triton_copy_func"],
|
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
|
||||||
ylabel="ms",
|
|
||||||
plot_name=f"kvcache_copy_decoding_stage-batch-{BATCH}",
|
|
||||||
args={"bsz": BATCH, "block_size": 16, "max_seq_len": 8192, "num_kv_heads": 16, "same_context_len": True},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def benchmark_kvcache_copy(
|
|
||||||
provider: str,
|
|
||||||
bsz: int,
|
|
||||||
block_size: int,
|
|
||||||
max_seq_len: int,
|
|
||||||
KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
|
|
||||||
num_kv_heads: int,
|
|
||||||
same_context_len: bool,
|
|
||||||
):
|
|
||||||
dtype = torch.float16
|
|
||||||
device = get_current_device()
|
|
||||||
|
|
||||||
assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len"
|
|
||||||
|
|
||||||
new_k, new_v, k_cache, v_cache, context_lengths, block_tables = prepare_data(
|
|
||||||
bsz,
|
|
||||||
num_kv_heads,
|
|
||||||
HEAD_DIM,
|
|
||||||
block_size,
|
|
||||||
max_seq_len // block_size,
|
|
||||||
same_context_len,
|
|
||||||
KV_SEQ_LEN,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
|
||||||
# TODO copy_to_cache needs to support copying both k and v at the same time in the future.
|
|
||||||
if provider == "torch_copy_func":
|
|
||||||
fn = lambda: copy_to_cache(new_k, k_cache, lengths=context_lengths, block_tables=block_tables, type="decoding")
|
|
||||||
if provider == "triton_copy_func":
|
|
||||||
fn = lambda: copy_kv_to_blocked_cache(new_k, new_v, k_cache, v_cache, context_lengths, block_tables)
|
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
|
||||||
return ms, min_ms, max_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
test_copy_kv_to_caches(4, 32, 8, 16, True)
|
||||||
# benchmark_kvcache_copy.run(save_path=".", print_data=True)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user