[feat] cuda graph support and refactor non-functional api
@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaModel,
 )

-from colossalai.inference.batch_bucket import BatchBucket
+from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
@@ -36,10 +36,12 @@ except ImportError:

 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
-    batch: BatchBucket = None,
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaForCausalLM.

     Args:
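For orientation, here is a minimal sketch of the metadata object the new signature consumes. The field list is inferred from the attributes this diff reads off inputmetadata; the real InputMetaData in colossalai.inference.config may carry more fields and stricter types.

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class InputMetaDataSketch:
    # Hypothetical stand-in for colossalai.inference.config.InputMetaData;
    # only the attributes this diff actually reads are listed.
    block_tables: torch.Tensor      # (batch_size, max_blocks_per_seq) paged-KV block table
    sequence_lengths: torch.Tensor  # (batch_size,) current length of each sequence
    batch_size: int                 # number of sequences in the batch
    kv_seq_len: int                 # max KV length, precomputed on the host
    is_prompts: bool                # True for prefill, False for decoding
    head_dim: int                   # per-head hidden size, used to derive sm_scale
    fd_inter_tensor: Optional[object] = None  # flash-decoding workspace (FDIntermTensors)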
@@ -51,7 +53,9 @@ def llama_causal_lm_forward(
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
-        batch=batch,
+        input_tokens_ids=input_tokens_ids,
+        output_tensor=output_tensor,
+        inputmetadata=inputmetadata,
         k_caches=k_caches,
         v_caches=v_caches,
     )
@@ -61,10 +65,12 @@ def llama_causal_lm_forward(

 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchBucket = None,
+    input_tokens_ids: torch.Tensor,
+    output_tensor: torch.Tensor,
+    inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """This function will replace the forward function of LlamaModel.

     Args:
@@ -72,11 +78,10 @@ def llama_model_forward(
         k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
         v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
     """
-    input_ids = batch.get_1D_inputs()
-    block_tables = batch.get_block_table_tensor()
-    sequence_lengths = batch.get_sequence_lengths()
-    batch_size = batch.current_batch_size
-    kv_seq_len = sequence_lengths.max().item()
+    block_tables = inputmetadata.block_tables
+    sequence_lengths = inputmetadata.sequence_lengths
+    batch_size = inputmetadata.batch_size
+    kv_seq_len = inputmetadata.kv_seq_len
     use_cuda_kernel = True
     # NOTE: After testing, the performance of this configuration is relatively good. With updates
     # and optimizations to the CUDA kernel implementation, a more detailed analysis of this configuration's
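The removed lines above show where each of these values used to be computed, so a caller that still holds a BatchBucket can derive the metadata once, up front, outside the graph-capturable forward. A hedged call-site sketch, reusing the BatchBucket accessors from the removed lines; the keyword-argument InputMetaData constructor is an assumption:

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InputMetaData

def build_input_metadata(batch: BatchBucket) -> InputMetaData:
    # Hypothetical helper: hoists the state reads this diff removed from
    # llama_model_forward out to the call site. The kwargs below are assumed,
    # not taken from the source.
    sequence_lengths = batch.get_sequence_lengths()
    return InputMetaData(
        block_tables=batch.get_block_table_tensor(),
        sequence_lengths=sequence_lengths,
        batch_size=batch.current_batch_size,
        # .item() forces a host-device sync; doing it once here keeps the
        # sync out of the (potentially graph-captured) forward pass
        kv_seq_len=sequence_lengths.max().item(),
        is_prompts=batch.is_prompts,
        head_dim=batch.head_dim,
        fd_inter_tensor=batch.fd_inter_tensor,
    )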
@@ -84,21 +89,13 @@ def llama_model_forward(
     if batch_size >= 32 and kv_seq_len > 512:
         use_cuda_kernel = False

-    hidden_states = self.embed_tokens(input_ids)
+    hidden_states = self.embed_tokens(input_tokens_ids)

-    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)
+    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts)

-    if batch.is_prompts:
-        output_tensor = torch.zeros(
-            (sequence_lengths.sum().item(), batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    else:
-        output_tensor = torch.zeros(
-            (batch_size, batch.num_heads * batch.head_dim), dtype=batch.dtype, device=batch.device
-        )
-    sm_scale = 1.0 / (batch.head_dim**0.5)
+    sm_scale = 1.0 / (inputmetadata.head_dim**0.5)

-    norm_output = torch.empty_like(hidden_states)
+    norm_output = None
     residual = None

     for layer_id, decoder_layer in enumerate(self.layers):
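Note what was deleted here: the forward no longer calls torch.zeros for the attention output buffer. CUDA graphs replay fixed memory addresses, so allocations must happen once, outside the captured region; output_tensor is now passed in by the caller. A sketch of that caller-side preallocation, reproducing the shapes from the removed branch (the helper name and its argument list are hypothetical):

import torch

def preallocate_attn_output(inputmetadata, num_heads, head_dim,
                            dtype=torch.float16, device="cuda"):
    # Hypothetical helper mirroring the deleted allocation branch.
    if inputmetadata.is_prompts:
        # prefill: one row per input token across the whole batch
        rows = inputmetadata.sequence_lengths.sum().item()
    else:
        # decoding: one row per sequence -- a static shape, which is
        # exactly what CUDA graph capture needs
        rows = inputmetadata.batch_size
    return torch.zeros((rows, num_heads * head_dim), dtype=dtype, device=device)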
@@ -108,22 +105,22 @@ def llama_model_forward(
             block_tables=block_tables,
             k_cache=k_caches[layer_id],
             v_cache=v_caches[layer_id],
-            is_prompts=batch.is_prompts,
+            is_prompts=inputmetadata.is_prompts,
             sequence_lengths=sequence_lengths,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
-            fd_inter_tensor=batch.fd_inter_tensor,
+            fd_inter_tensor=inputmetadata.fd_inter_tensor,
             output_tensor=output_tensor,
             norm_output=norm_output,
             sm_scale=sm_scale,
+            use_cuda_kernel=use_cuda_kernel,
         )

-    if batch.is_prompts:
+    if inputmetadata.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
-        norm_output = torch.empty_like(hidden_states)
+        norm_output = torch.empty_like(hidden_states)  # NOTE non-functional, but cuda graph only capture decoding only
     hidden_states, _ = self.norm(hidden_states, norm_output, residual)

     return hidden_states
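Putting the pieces together: with all buffers at fixed addresses, the decode step can be captured and replayed via torch.cuda.CUDAGraph. Below is a minimal sketch of that standard PyTorch pattern, with a tiny linear layer standing in for the patched llama_model_forward; shapes and the model are placeholders, and the actual ColossalAI graph runner may wire this differently.

import torch

# Stand-in model: a single linear layer instead of the patched forward.
device = "cuda"
model = torch.nn.Linear(4096, 4096, device=device)

# Static buffers: capture records their addresses, replay reuses them.
static_input = torch.zeros(8, 4096, device=device)   # decode batch: one row per sequence
static_output = torch.zeros(8, 4096, device=device)

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output.copy_(model(static_input))
torch.cuda.current_stream().wait_stream(s)

# Capture one decode step into a graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output.copy_(model(static_input))

# Per decoding step: refresh the static input in place, then replay the
# whole captured graph with a single launch.
static_input.copy_(torch.randn(8, 4096, device=device))
g.replay()
result = static_output.clone()  # replay overwrote static_output in place

This is also why the graph covers decoding only, per the inline NOTE above: the decode batch has a fixed (batch_size, hidden) shape, while prefill's token dimension (the sum of the prompt lengths) changes with every batch.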