[Inference] Add CUDA KVCache Kernel (#5406)

* add CUDA KVCache kernel

* annotate benchmark_kvcache_copy

* add use_cuda_kernel flag

* fix import path

* move benchmark scripts to example/

* remove benchmark code from test_kv_cache_memcpy.py

* remove redundant code

* remove redundant code

* revise PR according to review comments
yuehuayingxueluo
2024-02-28 14:36:50 +08:00
committed by GitHub
parent 19061188c3
commit 600881a8ea
15 changed files with 348 additions and 75 deletions
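
For context, the new inference_ops.decode_kv_cache_memcpy op called in the diff below writes the freshly decoded key/value states into the paged KV cache. The following pure-PyTorch sketch illustrates the semantics of that copy only; it is not the CUDA kernel, and the tensor layouts are assumptions inferred from the diff (e.g. block_size = k_cache.size(-2)), not taken from the kernel source.

    import torch

    def decode_kv_cache_memcpy_ref(key_states, value_states, k_cache, v_cache,
                                   sequence_lengths, block_tables):
        # Pure-PyTorch reference of the decode-time KV cache copy (illustrative only).
        # Assumed layouts (hypothetical):
        #   key_states / value_states: [batch_size, num_heads, head_dim]   (one new token per sequence)
        #   k_cache / v_cache:         [num_blocks, num_heads, block_size, head_dim]
        #   sequence_lengths:          [batch_size], current length of each sequence incl. the new token
        #   block_tables:              [batch_size, max_blocks_per_seq], physical block ids per sequence
        block_size = k_cache.size(-2)
        for seq_id in range(key_states.size(0)):
            token_pos = int(sequence_lengths[seq_id]) - 1                  # slot of the newly decoded token
            block_id = int(block_tables[seq_id, token_pos // block_size])  # block that holds this slot
            offset = token_pos % block_size                                # position inside the block
            k_cache[block_id, :, offset, :] = key_states[seq_id]
            v_cache[block_id, :, offset, :] = value_states[seq_id]

Under these assumed layouts, the call site in the diff, inference_ops.decode_kv_cache_memcpy(key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables), performs the same scatter on the GPU.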


@@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import (
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import (
context_attention_unpadded,
decoding_fused_rotary_embedding,
@@ -22,6 +23,8 @@ from colossalai.kernel.triton import (
)
from colossalai.logging import get_dist_logger
inference_ops = InferenceOpsLoader().load()
logger = get_dist_logger(__name__)
try:
@@ -74,6 +77,12 @@ def llama_model_forward(
sequence_lengths = batch.get_sequence_lengths()
batch_size = batch.current_batch_size
kv_seq_len = sequence_lengths.max().item()
use_cuda_kernel = True
# NOTE: In testing, this configuration performed relatively well. As the CUDA kernel
# implementation is updated and optimized, the selection of these thresholds should be
# re-analyzed in more detail.
if batch_size >= 32 and kv_seq_len > 512:
use_cuda_kernel = False
hidden_states = self.embed_tokens(input_ids)
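
The use_cuda_kernel toggle added in the hunk above reduces to the predicate sketched below. The helper name and the idea of factoring it out are hypothetical; the thresholds (32 and 512) are the values hard-coded in the diff.

    def should_use_cuda_kernel(batch_size: int, kv_seq_len: int) -> bool:
        # Per the NOTE above: the CUDA copy path is used by default, with a fallback
        # to the fused Triton kernel for large decode batches over long contexts.
        # The thresholds were chosen empirically and may need re-tuning as the
        # CUDA kernel implementation is optimized further.
        return not (batch_size >= 32 and kv_seq_len > 512)

    use_cuda_kernel = should_use_cuda_kernel(batch_size, kv_seq_len)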
@@ -107,6 +116,7 @@ def llama_model_forward(
output_tensor=output_tensor,
norm_output=norm_output,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
)
if batch.is_prompts:
@@ -134,6 +144,7 @@ def llama_decoder_layer_forward(
output_tensor: torch.Tensor = None,
norm_output: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""This function will replace the forward function of LlamaDecoderLayer.
@@ -153,6 +164,7 @@ def llama_decoder_layer_forward(
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
"""
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
@@ -169,6 +181,7 @@ def llama_decoder_layer_forward(
fd_inter_tensor=fd_inter_tensor,
output_tensor=output_tensor,
sm_scale=sm_scale,
use_cuda_kernel=use_cuda_kernel,
)
# Fully Connected
@@ -252,6 +265,7 @@ class NopadLlamaAttention(LlamaAttention):
fd_inter_tensor: FDIntermTensors = None,
output_tensor: torch.Tensor = None,
sm_scale: int = None,
use_cuda_kernel: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Args:
@@ -268,6 +282,7 @@ class NopadLlamaAttention(LlamaAttention):
storing intermediate values in flash-decoding. Defaults to None.
output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
sm_scale (int, optional): Used for flash attention. Defaults to None.
use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
"""
if self.num_heads != self.num_key_value_heads:
@@ -283,7 +298,6 @@ class NopadLlamaAttention(LlamaAttention):
)
block_size = k_cache.size(-2)
if is_prompts:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
@@ -300,17 +314,23 @@ class NopadLlamaAttention(LlamaAttention):
sm_scale=sm_scale,
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
if use_cuda_kernel:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
else:
decoding_fused_rotary_embedding(
query_states,
key_states,
value_states,
cos_sin[0],
cos_sin[1],
k_cache,
v_cache,
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,