feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)

Author: Steve Luo
Date: 2024-03-08 16:21:12 +08:00
Committed by: GitHub
Parent: 2b28b54ac6
Commit: f7aecc0c6b
8 changed files with 244 additions and 49 deletions


@@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaMLP,
     LlamaModel,
+    LlamaRMSNorm,
 )
 from colossalai.inference.batch_bucket import BatchBucket
@@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
     decoding_fused_rotary_embedding,
     flash_decoding_attention,
     get_xine_cache,
+    rms_layernorm,
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
@@ -124,7 +126,7 @@ def llama_model_forward(
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
         residual = residual[last_token_indexs - 1].contiguous()
     norm_output = torch.empty_like(hidden_states)
-    hidden_states, _ = self.norm(hidden_states, norm_output, residual)
+    hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
     return hidden_states
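
For reference, the norm call above (and the kernels this commit adds) implements standard RMSNorm, the same semantics as LlamaRMSNorm. A minimal pure-PyTorch sketch of that computation, written here only to document the expected behaviour and not taken from this commit:

import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # RMSNorm: scale by the reciprocal root-mean-square over the hidden dimension,
    # then apply the learned per-channel weight (no mean subtraction, no bias).
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype)

The norm_output buffer is pre-allocated with torch.empty_like so the CUDA and Triton kernels can write their result out-of-place without allocating inside the forward pass.
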
@@ -167,7 +169,7 @@ def llama_decoder_layer_forward(
         use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
     """
-    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -185,12 +187,32 @@ def llama_decoder_layer_forward(
     )
     # Fully Connected
-    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
+    hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
     hidden_states = self.mlp(hidden_states)
     return hidden_states, residual
+def llama_rmsnorm_forward(
+    self: LlamaRMSNorm,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
+            return hidden_states, residual
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
 class NopadLlamaAttention(LlamaAttention):
     def __init__(
         self,
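
To connect the diff with the unittest and benchmark mentioned in the commit title, below is a hedged sketch of how the new dispatch could be checked. The fused-path contract assumed here (the kernel returns the normalized sum plus the updated running residual) is inferred from how the callers reuse residual, and the shapes and tolerances are illustrative values, not ones taken from the shipped test.

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

def fused_add_rmsnorm_reference(hidden_states, residual, weight, eps):
    # Assumed contract of inference_ops.fused_add_rms_layernorm: fold the residual
    # add into the normalization and also hand back the updated running residual.
    residual = residual + hidden_states
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = weight * (residual.float() * torch.rsqrt(variance + eps)).to(weight.dtype)
    return normed, residual

def check_rmsnorm_dispatch(batch: int = 16, hidden: int = 4096, eps: float = 1e-5):
    # llama_rmsnorm_forward is the dispatch function defined in the diff above;
    # in a standalone test it would be imported from the patched modeling module.
    norm = LlamaRMSNorm(hidden, eps=eps).cuda().half()
    x = torch.randn(batch, hidden, dtype=torch.float16, device="cuda")

    ref = norm(x)  # stock HuggingFace forward as the reference
    out, _ = llama_rmsnorm_forward(norm, x, torch.empty_like(x), residual=None, use_cuda_kernel=False)
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)

    # With use_cuda_kernel=True (and the inference_ops extension built), the same
    # call goes through inference_ops.rms_layernorm / fused_add_rms_layernorm instead.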