[Inference] Fused the gate and up projections in the MLP, and optimized the autograd process. (#5365)

* fused the gate and up proj in mlp

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
Author: yuehuayingxueluo
Date: 2024-02-06 19:38:25 +08:00 (committed by GitHub)
Parent: 1dedb57747
Commit: 35382a7fbf
10 changed files with 484 additions and 50 deletions


@@ -32,7 +32,6 @@ except ImportError:
     logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")
-@torch.no_grad()
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
     batch: BatchInfo = None,
@@ -58,7 +57,6 @@ def llama_causal_lm_forward(
     return logits
-@torch.no_grad()
 def llama_model_forward(
     self: LlamaModel,
     batch: BatchInfo = None,
@@ -120,7 +118,6 @@ def llama_model_forward(
     return hidden_states
-@torch.no_grad()
 def llama_decoder_layer_forward(
     self: LlamaDecoderLayer,
     hidden_states: torch.Tensor,
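The three hunks above (and the matching ones in NopadLlamaAttention and NopadLlamaMLP further down) delete the per-function `@torch.no_grad()` decorators; this is the "opt auto_grad" part of the commit. A minimal sketch of the idea, assuming the surrounding engine already runs generation inside a single no-grad context. The diff does not show where that outer context lives, so `torch.inference_mode()` below is an illustrative choice, not the project's actual mechanism:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for a decoder forward; not the ColossalAI function.
linear = torch.nn.Linear(16, 16)

def decoder_forward(x: torch.Tensor) -> torch.Tensor:
    # No @torch.no_grad() here: gradient tracking is already disabled by the
    # caller's context, so a per-call decorator only adds overhead.
    return F.silu(linear(x))

# Assumption: one outer context around the whole generation loop covers
# every nested forward call.
with torch.inference_mode():
    out = decoder_forward(torch.randn(4, 16))

assert not out.requires_grad  # no autograd graph was recorded
```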
@@ -139,7 +136,7 @@ def llama_decoder_layer_forward(
     """This function will replace the forward function of LlamaDecoderLayer.
     Args:
-        hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
+        hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
         block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
             storing mapping of token_position_id -> block_id. Defaults to None.
         k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -154,8 +151,8 @@ def llama_decoder_layer_forward(
         norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
-    residual = hidden_states
+    residual = hidden_states
     hidden_states = self.input_layernorm(hidden_states, norm_output)
     # Self Attention
     hidden_states = self.self_attn(
@@ -240,7 +237,6 @@ class NopadLlamaAttention(LlamaAttention):
         return attn_layer
     # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
-    @torch.no_grad()
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -258,8 +254,8 @@ class NopadLlamaAttention(LlamaAttention):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
         Args:
-            hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
-            residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj.
+            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+            residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
             block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                 storing mapping of token_position_id -> block_id. Defaults to None.
             k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -321,7 +317,7 @@ class NopadLlamaAttention(LlamaAttention):
             sm_scale=sm_scale,
         )
-        attn_output = attn_output.reshape(-1, self.hidden_size)
+        attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
         return attn_output
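Two details in this hunk: `view` replaces `reshape` (the "Change reshape to view" commit), and the residual addition stays fused into the output projection via `torch.addmm`. A small sketch with made-up shapes of why both work here: `view` never copies and raises on non-contiguous input (where `reshape` may silently copy), and `addmm(residual, x, W)` computes `residual + x @ W` in one call:

```python
import torch

# Made-up sizes for illustration only.
token_num, num_heads, head_dim = 8, 4, 16
hidden_size = num_heads * head_dim

attn_output = torch.randn(token_num, num_heads, head_dim)  # stand-in for the attention output (contiguous here)
residual = torch.randn(token_num, hidden_size)
o_proj_weight = torch.randn(hidden_size, hidden_size)      # assumed already transposed, as in the module

flat = attn_output.view(-1, hidden_size)          # reuses the same storage, no copy
out = torch.addmm(residual, flat, o_proj_weight)  # residual + flat @ o_proj_weight in one call

assert flat.data_ptr() == attn_output.data_ptr()  # view shares memory
assert torch.allclose(out, flat @ o_proj_weight + residual, rtol=1e-4, atol=1e-5)
```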
@@ -345,9 +341,10 @@ class NopadLlamaMLP(LlamaMLP):
             mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
         """
         super().__init__(config)
-        self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
-        self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
+        self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
         self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
+        self.gate_proj = None
+        self.up_proj = None
     @staticmethod
     def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -371,15 +368,14 @@
         return mlp_layer
-    @torch.no_grad()
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
-            residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj.
+            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+            residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
         """
-        gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
-        act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
-        up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
-        tmp_out = act_out * up_proj_out
+        hidden_states = hidden_states.expand(2, -1, -1)
+        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
+        act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
+        tmp_out = act_out * gate_up_proj_out[1]
         return torch.addmm(residual, tmp_out, self.down_proj.weight)
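The fused forward above replaces the two separate `torch.mm` calls with one `torch.bmm` over the stacked gate/up weight of shape [2, hidden_size, intermediate_size]; `expand` adds the batch dimension without copying the activations. A hedged equivalence check with made-up sizes and standalone tensors (not the actual NopadLlamaMLP):

```python
import torch
import torch.nn.functional as F

# Made-up sizes; in the module the weights are the transposed HF LlamaMLP weights.
token_num, hidden_size, intermediate_size = 8, 32, 64
hidden_states = torch.randn(token_num, hidden_size)
residual = torch.randn(token_num, hidden_size)
gate_w = torch.randn(hidden_size, intermediate_size)  # transposed gate_proj weight
up_w = torch.randn(hidden_size, intermediate_size)    # transposed up_proj weight
down_w = torch.randn(intermediate_size, hidden_size)  # transposed down_proj weight

# Unfused reference path: two mm calls, SiLU-gated product, then down_proj + residual.
ref = torch.addmm(residual, F.silu(hidden_states @ gate_w) * (hidden_states @ up_w), down_w)

# Fused path, as in the new forward.
gate_up_weight = torch.stack([gate_w, up_w], dim=0)  # [2, hidden_size, intermediate_size]
batched = hidden_states.expand(2, -1, -1)            # [2, token_num, hidden_size], no copy
gate_up_out = torch.bmm(batched, gate_up_weight)     # [2, token_num, intermediate_size]
fused = torch.addmm(residual, F.silu(gate_up_out[0]) * gate_up_out[1], down_w)

assert torch.allclose(ref, fused, rtol=1e-4, atol=1e-4)
```

The payoff of the fusion is a single batched GEMM call per MLP invocation for the gate and up projections instead of two separate `torch.mm` launches.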