[feat] cuda graph support and refactor non-functional api

Author: Runyu Lu
Date: 2024-03-08 14:19:35 +08:00
parent 593a72e4d5
commit cefaeb5fdd
5 changed files with 281 additions and 43 deletions


@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaModel,
 )
-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
@@ -36,10 +36,12 @@ except ImportError:
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
-    batch: BatchBucket = None,
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaForCausalLM.
     Args:
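For orientation, a hedged caller-side sketch (not part of the diff): with the new signature, every piece of per-step state reaches the forward as a plain tensor or the InputMetaData container instead of a mutable BatchBucket. Names such as engine_model and input_meta are illustrative only.

# Hypothetical usage sketch; only the keyword arguments are taken from the diff.
# In practice this function is patched in as LlamaForCausalLM.forward.
logits = llama_causal_lm_forward(
    engine_model,                        # patched LlamaForCausalLM instance
    input_tokens_ids=input_tokens_ids,   # 1-D tensor of token ids for this step
    output_tensor=output_tensor,         # pre-allocated attention output buffer
    inputmetadata=input_meta,            # block tables, sequence lengths, flags, ...
    k_caches=k_caches,                   # per-layer key cache tensors
    v_caches=v_caches,                   # per-layer value cache tensors
)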
@@ -51,7 +53,9 @@ def llama_causal_lm_forward(
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
-        batch=batch,
+        input_tokens_ids=input_tokens_ids,
+        output_tensor=output_tensor,
+        inputmetadata=inputmetadata,
         k_caches=k_caches,
         v_caches=v_caches,
     )
@@ -61,10 +65,12 @@ def llama_causal_lm_forward(
 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchBucket = None,
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaModel.
     Args:
@@ -72,11 +78,10 @@ def llama_model_forward(
         k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
         v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
     """
-    input_ids = batch.get_1D_inputs()
-    block_tables = batch.get_block_table_tensor()
-    sequence_lengths = batch.get_sequence_lengths()
-    batch_size = batch.current_batch_size
-    kv_seq_len = sequence_lengths.max().item()
+    block_tables = inputmetadata.block_tables
+    sequence_lengths = inputmetadata.sequence_lengths
+    batch_size = inputmetadata.batch_size
+    kv_seq_len = inputmetadata.kv_seq_len
     use_cuda_kernel = True
     # NOTE: After testing, the performance of this configuration is relatively good. With updates
     # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
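The attributes read here (together with is_prompts, head_dim, and fd_inter_tensor used later in the function) suggest roughly the following shape for InputMetaData. This is a reconstruction from usage only, not the actual definition in colossalai.inference.config:

from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class InputMetaDataSketch:
    """Illustrative stand-in for colossalai.inference.config.InputMetaData,
    reconstructed from the attributes this forward pass reads."""

    block_tables: torch.Tensor      # [batch_size, max_blocks_per_seq] KV-cache block ids
    sequence_lengths: torch.Tensor  # [batch_size] current length of each sequence
    batch_size: int                 # number of sequences in this step
    kv_seq_len: int                 # max KV length across the batch
    is_prompts: bool                # True for prefill, False for decoding
    head_dim: int                   # per-head hidden size, used for sm_scale
    fd_inter_tensor: Any            # FDIntermTensors workspace for flash decoding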
@@ -84,21 +89,13 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False
-    hidden_states = self.embed_tokens(input_ids)
+    hidden_states = self.embed_tokens(input_tokens_ids)
-    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
+    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)
-    if batch.is_prompts:
-        output_tensor = torch.zeros(
-            (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    else:
-        output_tensor = torch.zeros(
-            (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    sm_scale = 1.0 / (batch.head_dim**0.5)
+    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)
-    norm_output = torch.empty_like(hidden_states)
+    norm_output = None
     residual = None
     for layer_id, decoder_layer in enumerate(self.layers):
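The output_tensor allocation removed above does not disappear; it becomes the caller's responsibility, so the buffer can be allocated once and reused across CUDA graph replays. A hedged sketch that mirrors the removed shapes (num_heads, head_dim, dtype, and device would come from the model configuration):

# Sketch only: pre-allocate the attention output buffer outside the forward,
# reproducing the shapes of the in-forward allocation deleted in this hunk.
if input_meta.is_prompts:
    # Prefill: one row per input token across the whole batch.
    output_tensor = torch.zeros(
        (input_meta.sequence_lengths.sum().item(), num_heads * head_dim),
        dtype=dtype,
        device=device,
    )
else:
    # Decoding: one row per sequence; the shape is fixed, so the buffer can
    # keep a stable address inside a captured CUDA graph.
    output_tensor = torch.zeros(
        (input_meta.batch_size, num_heads * head_dim), dtype=dtype, device=device
    )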
@@ -108,22 +105,22 @@ def llama_model_forward(
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
-            is_prompts=batch.is_prompts,
+            is_prompts=inputmetadata.is_prompts,
             sequence_lengths=sequence_lengths,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
-            fd_inter_tensor=batch.fd_inter_tensor,
+            fd_inter_tensor=inputmetadata.fd_inter_tensor,
             output_tensor=output_tensor,
             norm_output=norm_output,
             sm_scale=sm_scale,
             use_cuda_kernel=use_cuda_kernel,
         )
-    if batch.is_prompts:
+    if inputmetadata.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
-        norm_output = torch.empty_like(hidden_states)
+        norm_output = torch.empty_like(hidden_states)  # NOTE: non-functional (allocates inside the forward), but acceptable since the CUDA graph captures the decoding stage only
     hidden_states, _ = self.norm(hidden_states, norm_output, residual)
     return hidden_states
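Finally, a hedged sketch of the capture/replay pattern this refactor enables. As the NOTE in the last hunk indicates, only the decoding step is captured; every tensor fed into the captured call must keep a fixed shape and address, and the names below (model, static_*, decode_meta, next_token_ids) are illustrative rather than taken from the commit:

import torch

# Warm-up run outside the graph so lazy initialization is not captured
# (real code typically does this on a side stream).
static_out = model(
    input_tokens_ids=static_input_ids,
    output_tensor=static_output_tensor,
    inputmetadata=decode_meta,          # is_prompts=False: decoding only
    k_caches=k_caches,
    v_caches=v_caches,
)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(
        input_tokens_ids=static_input_ids,
        output_tensor=static_output_tensor,
        inputmetadata=decode_meta,
        k_caches=k_caches,
        v_caches=v_caches,
    )

# Each decoding step: refresh the static inputs in place, then replay.
static_input_ids.copy_(next_token_ids)
decode_meta.sequence_lengths.add_(1)    # assumption: each sequence grows by one token
graph.replay()
next_token_logits = static_out          # refreshed by the replay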