[Inference/opt] Optimize the mid tensor of RMS Norm (#5350)

* opt rms_norm

* fix bugs in rms_layernorm
commit 21ad4a27f9
parent 027aa1043f
Author: yuehuayingxueluo
Date: 2024-02-02 15:06:01 +08:00 (committed by GitHub)
7 changed files with 34 additions and 35 deletions
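Both modified model files apply the same pattern: allocate the RMS-norm output buffer (the "mid tensor") once per forward pass with torch.empty_like and pass it into every layernorm call, instead of letting each call allocate a fresh output tensor. The sketch below only illustrates that buffer-reuse idea with a hypothetical out-parameter rms_layernorm written in eager PyTorch; the signature and the math here are assumptions for illustration, not ColossalAI's actual Triton kernel API.

import torch

def rms_layernorm(x: torch.Tensor, weight: torch.Tensor, eps: float, norm_output: torch.Tensor) -> torch.Tensor:
    # Hypothetical out-parameter RMS norm: the result is written into the
    # caller-provided buffer rather than into a freshly allocated tensor.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    torch.mul(x * torch.rsqrt(variance + eps), weight, out=norm_output)
    return norm_output

num_tokens, hidden_size, num_layers = 4, 128, 2
hidden_states = torch.randn(num_tokens, hidden_size)
weight = torch.ones(hidden_size)

# Allocate the mid tensor once, before the decoder-layer loop ...
norm_output = torch.empty_like(hidden_states)
for _ in range(num_layers):
    # ... and reuse the same buffer for every norm call inside the loop.
    normed = rms_layernorm(hidden_states, weight, 1e-6, norm_output)
    hidden_states = hidden_states + normed  # stand-in for attention/MLP plus residual

Since each decoder layer performs two norm calls (input_layernorm and post_attention_layernorm), reusing one buffer avoids 2 x num_layers allocations per forward pass.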


@@ -95,6 +95,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)
+    norm_output = torch.empty_like(hidden_states)
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,13 +109,15 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )
     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
-    hidden_states = self.norm(hidden_states)
+        norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states, norm_output)
     return hidden_states
@@ -131,6 +135,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -148,11 +153,12 @@ def llama_decoder_layer_forward(
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
             storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
     residual = hidden_states
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states, norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -171,7 +177,7 @@ def llama_decoder_layer_forward(
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
     hidden_states = self.mlp(hidden_states, residual)
     return hidden_states
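In the prompt (prefill) pass the hidden states are sliced down to the last token of each sequence before the final norm, so the buffer that was sized for the full token count no longer matches and a smaller norm_output is created in the is_prompts branch; in the decode pass the buffer allocated before the loop is simply reused. The hunks that follow come from the second modified file, which applies the same change to the padded [batch, seq, hidden] code path. Below is a small shape-only sketch of that prefill corner case, with illustrative names rather than the actual module code:

import torch

num_tokens, hidden_size = 7, 128
hidden_states = torch.randn(num_tokens, hidden_size)
norm_output = torch.empty_like(hidden_states)        # sized for all 7 prefill tokens

sequence_lengths = torch.tensor([3, 4])              # two sequences, 3 + 4 = 7 tokens
is_prompts = True
if is_prompts:
    last_token_indexs = sequence_lengths.cumsum(dim=-1)
    hidden_states = hidden_states[last_token_indexs - 1].contiguous()  # now [2, 128]
    norm_output = torch.empty_like(hidden_states)    # re-sized to match the sliced batch
assert norm_output.shape == hidden_states.shape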


@@ -135,6 +135,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)
+    norm_output = torch.empty_like(hidden_states)
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -149,12 +151,14 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )
     if batch.is_prompts:
         hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
-    hidden_states = self.norm(hidden_states)
+        norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     return hidden_states
@@ -174,6 +178,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -191,11 +196,12 @@ def llama_decoder_layer_forward(
         cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
     residual = hidden_states
-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -217,7 +223,7 @@ def llama_decoder_layer_forward(
     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     hidden_states = self.mlp(hidden_states)
     hidden_states = residual + hidden_states
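This second file keeps the activations as [batch, seq, hidden], so each norm call first flattens them to rows with reshape(-1, hidden_states.shape[-1]); for a contiguous tensor that reshape is a view, so the only buffer in play is the reused norm_output, which torch.empty_like sizes to the same number of elements as the input. The snippet below illustrates those two points only; how the actual kernel addresses the buffer is an assumption here.

import torch

batch, seq, hidden = 2, 1, 128
hidden_states = torch.randn(batch, seq, hidden)
norm_output = torch.empty_like(hidden_states)              # same element count as the input

rows = hidden_states.reshape(-1, hidden_states.shape[-1])  # [batch * seq, hidden]
assert rows.data_ptr() == hidden_states.data_ptr()         # a view, no copy for contiguous input
out_rows = norm_output.view(-1, norm_output.shape[-1])     # how a row-wise kernel could address the reused buffer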