Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Add CUDA KVCache Kernel (#5406)
* add CUDA KVCache kernel
* annotate benchmark_kvcache_copy
* add use_cuda_kernel option
* fix import path
* move benchmark scripts to example/ (see the timing sketch below)
* remove benchmark code from test_kv_cache_memcpy.py
* remove redundant code
* revise PR according to review comments
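Since the benchmark script itself was moved to example/ and is not shown here, the following is only a minimal sketch of how a KV-cache-copy kernel is typically timed with CUDA events; `time_cuda_op` is a hypothetical helper, not code from this commit.

```python
import torch

def time_cuda_op(fn, warmup: int = 10, iters: int = 100) -> float:
    """Return the mean latency of a CUDA callable in milliseconds."""
    for _ in range(warmup):  # warm up: JIT builds, caches, clock ramp-up
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait until all recorded launches finish
    return start.elapsed_time(end) / iters
```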
@@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import (
 from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
     context_attention_unpadded,
     decoding_fused_rotary_embedding,
@@ -22,6 +23,8 @@ from colossalai.kernel.triton import (
 )
 from colossalai.logging import get_dist_logger
 
+inference_ops = InferenceOpsLoader().load()
+
 logger = get_dist_logger(__name__)
 
 try:
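`InferenceOpsLoader().load()` resolves the compiled CUDA extension once at import time, so every forward pass reuses the same `inference_ops` handle. Below is a minimal sketch of invoking the new op directly; the call signature is taken from this diff, but the paged-cache layout `[num_blocks, num_heads, block_size, head_dim]` is an assumption inferred from `block_size = k_cache.size(-2)` further down, and all shapes are illustrative.

```python
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

# Illustrative shapes; cache layout [num_blocks, num_heads, block_size, head_dim] is assumed.
bsz, num_heads, head_dim, block_size, blocks_per_seq = 2, 8, 64, 16, 4
num_blocks = bsz * blocks_per_seq

key_states = torch.randn(bsz, num_heads, head_dim, device="cuda", dtype=torch.float16)
value_states = torch.randn_like(key_states)
k_cache = torch.zeros(num_blocks, num_heads, block_size, head_dim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
sequence_lengths = torch.tensor([5, 40], device="cuda", dtype=torch.int32)  # current length per sequence
block_tables = torch.arange(num_blocks, device="cuda", dtype=torch.int32).view(bsz, blocks_per_seq)

# Write this decoding step's K/V vectors into each sequence's paged-cache slot.
inference_ops.decode_kv_cache_memcpy(key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables)
```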
@@ -74,6 +77,12 @@ def llama_model_forward(
     sequence_lengths = batch.get_sequence_lengths()
     batch_size = batch.current_batch_size
     kv_seq_len = sequence_lengths.max().item()
+    use_cuda_kernel = True
+    # NOTE: After testing, this configuration performs relatively well. As the CUDA kernel
+    # implementation is updated and optimized, the selection heuristic below should be
+    # re-analyzed in more detail.
+    if batch_size >= 32 and kv_seq_len > 512:
+        use_cuda_kernel = False
 
     hidden_states = self.embed_tokens(input_ids)
 
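As the NOTE concedes, the cutover points are empirical. The dispatch logic reads as a standalone predicate; this sketch uses the thresholds committed here, and `should_use_cuda_kernel` is a hypothetical name, not part of the commit.

```python
def should_use_cuda_kernel(batch_size: int, kv_seq_len: int) -> bool:
    # Empirical cutover from the diff: for batches of 32+ sequences with KV length
    # over 512, the fused Triton path measured faster, so fall back to it there.
    return not (batch_size >= 32 and kv_seq_len > 512)
```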
@@ -107,6 +116,7 @@ def llama_model_forward(
             output_tensor=output_tensor,
             norm_output=norm_output,
             sm_scale=sm_scale,
+            use_cuda_kernel=use_cuda_kernel,
         )
 
     if batch.is_prompts:
@@ -134,6 +144,7 @@ def llama_decoder_layer_forward(
     output_tensor: torch.Tensor = None,
     norm_output: torch.Tensor = None,
     sm_scale: int = None,
+    use_cuda_kernel: bool = True,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
 
@@ -153,6 +164,7 @@ def llama_decoder_layer_forward(
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
         norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
+        use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
     """
 
     hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
@@ -169,6 +181,7 @@ def llama_decoder_layer_forward(
         fd_inter_tensor=fd_inter_tensor,
         output_tensor=output_tensor,
         sm_scale=sm_scale,
+        use_cuda_kernel=use_cuda_kernel,
     )
 
     # Fully Connected
@@ -252,6 +265,7 @@ class NopadLlamaAttention(LlamaAttention):
         fd_inter_tensor: FDIntermTensors = None,
         output_tensor: torch.Tensor = None,
         sm_scale: int = None,
+        use_cuda_kernel: bool = True,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
         Args:
@@ -268,6 +282,7 @@ class NopadLlamaAttention(LlamaAttention):
                 storing intermediate values in flash-decoding. Defaults to None.
             output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
             sm_scale (int, optional): Used for flash attention. Defaults to None.
+            use_cuda_kernel (bool, optional): Whether to use the CUDA kernel. Defaults to True.
         """
 
         if self.num_heads != self.num_key_value_heads:
@@ -283,7 +298,6 @@ class NopadLlamaAttention(LlamaAttention):
         )
 
         block_size = k_cache.size(-2)
-
         if is_prompts:
             rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
@@ -300,17 +314,23 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )
         else:
-            decoding_fused_rotary_embedding(
-                query_states,
-                key_states,
-                value_states,
-                cos_sin[0],
-                cos_sin[1],
-                k_cache,
-                v_cache,
-                block_tables,
-                sequence_lengths,
-            )
+            if use_cuda_kernel:
+                rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+                inference_ops.decode_kv_cache_memcpy(
+                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
+                )
+            else:
+                decoding_fused_rotary_embedding(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cos_sin[0],
+                    cos_sin[1],
+                    k_cache,
+                    v_cache,
+                    block_tables,
+                    sequence_lengths,
+                )
             attn_output = flash_decoding_attention(
                 q=query_states,
                 k_cache=k_cache,
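Two details of this hunk are worth spelling out. First, the CUDA path deliberately splits the work the Triton fallback fuses: `rotary_embedding` runs as its own kernel, then `decode_kv_cache_memcpy` copies the rotated keys (and the values) into the paged cache, whereas `decoding_fused_rotary_embedding` does both in one launch. Second, for readers new to paged KV caches, here is a pure-PyTorch sketch of the semantics `decode_kv_cache_memcpy` is expected to implement, assuming the `[num_blocks, num_heads, block_size, head_dim]` layout implied by `block_size = k_cache.size(-2)`; `decode_kv_cache_memcpy_ref` is a hypothetical reference, and the actual CUDA kernel parallelizes this loop.

```python
import torch

def decode_kv_cache_memcpy_ref(key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables):
    """Reference semantics (a sketch, not the shipped kernel): for each sequence,
    write this decoding step's K/V vectors into the paged-cache slot that
    corresponds to the newest token."""
    block_size = k_cache.size(-2)
    for seq_idx, seq_len in enumerate(sequence_lengths.tolist()):
        token_idx = seq_len - 1                                   # position of the newly generated token
        block_id = block_tables[seq_idx, token_idx // block_size]  # which physical block holds it
        offset = token_idx % block_size                           # slot within that block
        k_cache[block_id, :, offset, :] = key_states[seq_idx]
        v_cache[block_id, :, offset, :] = value_states[seq_idx]
```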