[feat] add use_cuda_kernel option

Author: Runyu Lu
Date: 2024-03-19 13:24:25 +08:00
Parent: 6e30248683
Commit: aabc9fb6aa
3 changed files with 11 additions and 2 deletions

@@ -60,6 +60,7 @@ def llama_causal_lm_forward(
         inputmetadata=inputmetadata,
         k_caches=k_caches,
         v_caches=v_caches,
+        use_cuda_kernel=inputmetadata.use_cuda_kernel,  # NOTE: currently the CUDA kernels for layernorm and rotary_embedding_and_cache_copy cannot pass the unit tests, but the Triton kernels can
     )
     logits = torch.mm(hidden_states, self.lm_head.weight)
     return logits
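
The inline note reports that the CUDA layernorm and rotary_embedding_and_cache_copy kernels fail the unit tests while their Triton counterparts pass. A minimal parity check of that shape, sketched below with hypothetical names (this is not the repository's actual test suite, and it needs a CUDA device), compares a kernel under test against a plain-PyTorch RMSNorm reference:

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain-PyTorch RMSNorm used as ground truth; compute in fp32 for stability.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def kernel_matches_reference(kernel_fn, atol: float = 4e-3, rtol: float = 4e-3) -> bool:
    # kernel_fn is whichever backend (CUDA or Triton) is under test.
    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    weight = torch.ones(128, dtype=torch.float16, device="cuda")
    return torch.allclose(kernel_fn(x, weight), rmsnorm_reference(x, weight), atol=atol, rtol=rtol)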
@@ -72,6 +73,7 @@ def llama_model_forward(
     inputmetadata: InputMetaData,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
+    use_cuda_kernel: Optional[bool] = True,
 ) -> torch.Tensor:
     """This function will replace the forward function of LlamaModel.
@@ -84,8 +86,7 @@ def llama_model_forward(
     sequence_lengths = inputmetadata.sequence_lengths
     batch_size = inputmetadata.batch_size
     kv_seq_len = inputmetadata.kv_seq_len
-    # use_cuda_kernel = False
-    use_cuda_kernel = True
     # NOTE: After testing, the performance of this configuration is relatively good. With updates
     # and optimizations to the CUDA kernel implementation, a more detailed analysis of when to
     # select this configuration should be conducted.
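
The NOTE above defers a detailed analysis of when the CUDA path should be selected. A generic micro-benchmark harness of roughly this shape (not from the repository; requires a CUDA device) is the usual way to gather that data, timing each backend on the shapes of interest:

import time

import torch

def mean_latency(fn, *args, warmup: int = 5, iters: int = 20) -> float:
    # Average seconds per call; synchronize so queued GPU work is fully timed.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

Running it once with use_cuda_kernel=True and once with False for each batch size and sequence length of interest yields the comparison the NOTE calls for.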