Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[Inference/opt]Optimize the mid tensor of RMS Norm (#5350)
* opt rms_norm
* fix bugs in rms_layernorm
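The change threads a pre-allocated norm_output buffer through the model forward so that every RMS-norm call writes into the same mid tensor instead of allocating a fresh output per layer. Below is a minimal plain-PyTorch sketch of that idea; rms_layernorm_sketch and its signature are illustrative only, since the real implementation behind self.norm and input_layernorm is not shown in this diff.

    from typing import Optional

    import torch


    def rms_layernorm_sketch(
        x: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-6,
        norm_output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Write into the caller-provided mid tensor when one is given;
        # only allocate when no buffer was passed in (standalone use).
        out = torch.empty_like(x) if norm_output is None else norm_output
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)
        torch.mul(normed, weight, out=out)
        return out

In the diff below, the buffer is created once per model forward with torch.empty_like(hidden_states), passed into each decoder layer as norm_output, and handed to input_layernorm and post_attention_layernorm; only the final norm after the prompt-phase slicing needs a second allocation.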
@@ -95,6 +95,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)

+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -107,13 +109,15 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )

     if batch.is_prompts:
         last_token_indexs = sequence_lengths.cumsum(dim=-1)
         hidden_states = hidden_states[last_token_indexs - 1].contiguous()
-    hidden_states = self.norm(hidden_states)
+    norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states, norm_output)

     return hidden_states

@@ -131,6 +135,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -148,11 +153,12 @@ def llama_decoder_layer_forward(
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
             storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
     residual = hidden_states

-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states, norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -171,7 +177,7 @@ def llama_decoder_layer_forward(

     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states, norm_output)
     hidden_states = self.mlp(hidden_states, residual)

     return hidden_states
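The hunks above come from the first modified file. One detail worth spelling out: norm_output is allocated a second time right before the final self.norm because the prompt phase keeps only the last token of each sequence, so hidden_states shrinks and the per-token buffer no longer matches. A small shape walk-through, with sizes made up for illustration:

    import torch

    hidden_size = 4096                              # hypothetical model width
    sequence_lengths = torch.tensor([16, 8, 24])    # hypothetical prompt lengths
    num_tokens = int(sequence_lengths.sum())        # 48 tokens packed into one batch

    hidden_states = torch.randn(num_tokens, hidden_size)
    norm_output = torch.empty_like(hidden_states)   # reused by every decoder layer: (48, 4096)

    # Prompt phase: keep only the last token of each sequence.
    last_token_indexs = sequence_lengths.cumsum(dim=-1)
    hidden_states = hidden_states[last_token_indexs - 1].contiguous()   # (3, 4096)

    # The (48, 4096) buffer no longer fits, hence the fresh allocation before the final norm.
    norm_output = torch.empty_like(hidden_states)                       # (3, 4096)

The second modified file (hunks below) makes the same change for a forward that keeps hidden_states in a (batch, seq_len, hidden) layout, which is why its call sites also flatten the tensor with reshape before handing it to the norm.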
@@ -135,6 +135,8 @@ def llama_model_forward(
     )
     sm_scale = 1.0 / (batch.head_dim**0.5)

+    norm_output = torch.empty_like(hidden_states)
+
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
             hidden_states,
@@ -149,12 +151,14 @@ def llama_model_forward(
             cos_sin=cos_sin,
             fd_inter_tensor=batch.fd_inter_tensor,
             output_tensor=output_tensor,
+            norm_output=norm_output,
             sm_scale=sm_scale,
         )

     if batch.is_prompts:
         hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
-    hidden_states = self.norm(hidden_states)
+    norm_output = torch.empty_like(hidden_states)
+    hidden_states = self.norm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)

     return hidden_states

@@ -174,6 +178,7 @@ def llama_decoder_layer_forward(
     cos_sin: Tuple[torch.Tensor] = None,
     fd_inter_tensor: FDIntermTensors = None,
     output_tensor: torch.Tensor = None,
+    norm_output: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """This function will replace the forward function of LlamaDecoderLayer.
@@ -191,11 +196,12 @@ def llama_decoder_layer_forward(
         cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None.
         fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
         output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None.
+        norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
     residual = hidden_states

-    hidden_states = self.input_layernorm(hidden_states)
+    hidden_states = self.input_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
@@ -217,7 +223,7 @@ def llama_decoder_layer_forward(

     # Fully Connected
     residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
+    hidden_states = self.post_attention_layernorm(hidden_states.reshape(-1, hidden_states.shape[-1]), norm_output)
     hidden_states = self.mlp(hidden_states)
     hidden_states = residual + hidden_states
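For this (batch, seq_len, hidden) layout, the reshape(-1, hidden_states.shape[-1]) at each call site is only a view, and the buffer allocated once with torch.empty_like still has the right number of elements to back the flattened output. A quick check with made-up sizes:

    import torch

    batch_size, seq_len, hidden_size = 4, 1, 4096   # hypothetical decode-phase sizes
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    norm_output = torch.empty_like(hidden_states)   # shared mid tensor, allocated once per forward

    flat = hidden_states.reshape(-1, hidden_states.shape[-1])  # (4, 4096) view, no copy for a contiguous tensor
    assert flat.numel() == norm_output.numel()                 # one buffer is enough for every norm call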