Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp
* fix code styles
* opt auto_grad
* rollback test_inference_engine.py
* modifications based on the review feedback
* fix bugs in flash attn
* Change reshape to view
* fix test_rmsnorm_triton.py
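The core change is visible in the NopadLlamaMLP hunks at the bottom of the diff: the separate gate and up projections are replaced by one batched matmul over a stacked weight tensor. Below is a minimal, self-contained sketch of that idea, written for this note (sizes and the unfused reference are illustrative assumptions, not code from the PR):

import torch

# Illustrative sizes (assumptions, not values from the PR).
token_num, embed_dim, inter_dim = 16, 128, 256

hidden_states = torch.randn(token_num, embed_dim)
residual = torch.randn(token_num, embed_dim)

# Transposed projection weights, matching the layout the no-pad modules store
# so that plain torch.mm / torch.bmm can consume them directly.
gate_w = torch.randn(embed_dim, inter_dim)   # gate_proj weight, transposed
up_w = torch.randn(embed_dim, inter_dim)     # up_proj weight, transposed
down_w = torch.randn(inter_dim, embed_dim)   # down_proj weight, transposed

# Unfused reference: two separate matmuls for the gate and up projections.
ref = torch.nn.functional.silu(hidden_states @ gate_w) * (hidden_states @ up_w)
ref = residual + ref @ down_w

# Fused version: stack the two weights into [2, embed_dim, inter_dim] and run
# a single batched matmul; expand() only re-strides hidden_states, no copy.
gate_up_w = torch.stack([gate_w, up_w], dim=0)
gate_up_out = torch.bmm(hidden_states.expand(2, -1, -1), gate_up_w)
fused = torch.addmm(
    residual,
    torch.nn.functional.silu(gate_up_out[0]) * gate_up_out[1],
    down_w,
)

assert torch.allclose(ref, fused, rtol=1e-4, atol=1e-4)

Stacking the two weights trades two separate GEMM launches for one batched GEMM; `expand(2, -1, -1)` produces a stride-0 view of the activations along the new batch dimension, so nothing is copied before the matmul.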
@@ -32,7 +32,6 @@ except ImportError:
     logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")


-@torch.no_grad()
 def llama_causal_lm_forward(
     self: LlamaForCausalLM,
     batch: BatchInfo = None,
@@ -58,7 +57,6 @@ def llama_causal_lm_forward(
     return logits


-@torch.no_grad()
 def llama_model_forward(
     self: LlamaModel,
     batch: BatchInfo = None,
@@ -120,7 +118,6 @@ def llama_model_forward(
     return hidden_states


-@torch.no_grad()
 def llama_decoder_layer_forward(
     self: LlamaDecoderLayer,
     hidden_states: torch.Tensor,
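The three hunks above (and the matching ones in NopadLlamaAttention and NopadLlamaMLP below) each drop a per-function `@torch.no_grad()` decorator, which is presumably what "opt auto_grad" refers to. A common way to keep inference gradient-free without per-function decorators, shown here only as a generic sketch and not as ColossalAI's actual engine code, is a single guard at the entry point:

import torch

def run_inference(model_forward, batch):
    # Generic sketch (hypothetical helper, not from this PR): one no_grad
    # context at the entry point covers every nested forward call, making
    # per-function @torch.no_grad() decorators redundant.
    with torch.no_grad():
        return model_forward(batch)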
@@ -139,7 +136,7 @@ def llama_decoder_layer_forward(
     """This function will replace the forward function of LlamaDecoderLayer.

     Args:
-        hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
+        hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
         block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
             storing mapping of token_position_id -> block_id. Defaults to None.
         k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -154,8 +151,8 @@ def llama_decoder_layer_forward(
         norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None.
         sm_scale (int, optional): Used for flash attention. Defaults to None.
     """
-    residual = hidden_states
-
+    residual = hidden_states
     hidden_states = self.input_layernorm(hidden_states, norm_output)
     # Self Attention
     hidden_states = self.self_attn(
@@ -240,7 +237,6 @@ class NopadLlamaAttention(LlamaAttention):
         return attn_layer

     # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
-    @torch.no_grad()
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -258,8 +254,8 @@ class NopadLlamaAttention(LlamaAttention):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """
         Args:
-            hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
-            residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in out_proj.
+            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+            residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj.
             block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
                 storing mapping of token_position_id -> block_id. Defaults to None.
             k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
@@ -321,7 +317,7 @@ class NopadLlamaAttention(LlamaAttention):
                 sm_scale=sm_scale,
             )

-        attn_output = attn_output.reshape(-1, self.hidden_size)
+        attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)

         return attn_output
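Two details in the attention hunk above deserve a short standalone illustration (a sketch with made-up sizes, not code from the PR): `view` re-strides a contiguous tensor without ever copying, whereas `reshape` may silently fall back to a copy, and `torch.addmm` folds the residual addition into the output-projection GEMM.

import torch

# Illustrative sizes (assumptions).
token_num, num_heads, head_dim = 8, 4, 16
hidden_size = num_heads * head_dim

attn_output = torch.randn(token_num, num_heads, head_dim)  # contiguous
residual = torch.randn(token_num, hidden_size)
o_proj_w = torch.randn(hidden_size, hidden_size)           # transposed o_proj weight

# view() never copies: it raises if the layout is incompatible, whereas
# reshape() would silently fall back to a copy in that case.
flat = attn_output.view(-1, hidden_size)
assert flat.data_ptr() == attn_output.data_ptr()

# addmm(residual, flat, w) == residual + flat @ w, but the add rides along
# with the GEMM instead of running as a separate elementwise kernel.
out = torch.addmm(residual, flat, o_proj_w)
assert torch.allclose(out, residual + flat @ o_proj_w, atol=1e-5)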
@@ -345,9 +341,10 @@ class NopadLlamaMLP(LlamaMLP):
             mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
         """
         super().__init__(config)
-        self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
-        self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
+        self.gate_up_weight = Parameter(torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0), requires_grad=False)
         self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
+        self.gate_proj = None
+        self.up_proj = None

     @staticmethod
     def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
@@ -371,15 +368,14 @@ class NopadLlamaMLP(LlamaMLP):

         return mlp_layer

-    @torch.no_grad()
     def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
-            residual (torch.Tensor): shape `(token_num, embed_dim)`, used to be added to hidden_states in down_proj.
+            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+            residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in down_proj.
         """
-        gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
-        act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
-        up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
-        tmp_out = act_out * up_proj_out
+        hidden_states = hidden_states.expand(2, -1, -1)
+        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
+        act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True)
+        tmp_out = act_out * gate_up_proj_out[1]
         return torch.addmm(residual, tmp_out, self.down_proj.weight)