From 249644c23b0402ccf9d0908f13ed15b41b95145f Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Thu, 1 Feb 2024 15:49:39 +0800
Subject: [PATCH] [Inference] Replace Attention layer and MLP layer by
 shardformer to optimize the weight transpose operation, add fused_qkv and
 fused linear_add (#5340)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add fused qkv

* replace attn and mlp by shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* optimize unbind

* add fused_addmm

* rm squeeze(1)

* refactor codes

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* Removed the dependency on LlamaFlashAttention2

* rollback test_inference_engine.py
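
Note on the fused-QKV change: when num_heads == num_key_value_heads, the three
separate q/k/v projections are replaced by one batched matmul over a stacked
weight tensor. A minimal standalone sketch of the idea (hypothetical shapes,
not part of this patch):

    import torch

    token_num, hidden_size, num_heads = 8, 64, 4
    head_dim = hidden_size // num_heads

    hidden_states = torch.randn(token_num, hidden_size)
    # Pre-transposed projection weights, each of shape (hidden_size, hidden_size).
    w_q, w_k, w_v = (torch.randn(hidden_size, hidden_size) for _ in range(3))

    # Stack once at load time: (3, hidden_size, hidden_size).
    qkv_weight = torch.stack([w_q, w_k, w_v], dim=0)

    # One torch.bmm instead of three torch.mm calls, then unbind the results.
    query, key, value = (
        torch.bmm(hidden_states.expand(3, -1, -1), qkv_weight)
        .view(3, token_num, num_heads, head_dim)
        .unbind(0)
    )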
---
 .../modeling/models/nopadding_llama.py        | 306 +++++++++++++----
 .../modeling/models/padding_llama.py          | 321 ++++++++++++------
 .../modeling/policy/nopadding_llama.py        |  60 ++--
 .../modeling/policy/padding_llama.py          | 135 +-------
 colossalai/kernel/triton/flash_decoding.py    |  10 +-
 examples/inference/run_benchmark.sh           |  14 +-
 tests/test_infer_ops/triton/kernel_utils.py   |   1 +
 .../triton/test_decoding_attn.py              |   4 +-
 8 files changed, 510 insertions(+), 341 deletions(-)
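
Note on the weight-transpose / fused linear_add change: the projection weights
are transposed once when the shardformer builds the replacement modules, and
the residual add is fused into the matmul via torch.addmm. A rough equivalence
check with hypothetical shapes (illustration only, not part of this patch):

    import torch

    token_num, hidden_size = 8, 64
    residual = torch.randn(token_num, hidden_size)
    attn_output = torch.randn(token_num, hidden_size)
    o_proj_weight = torch.randn(hidden_size, hidden_size)  # HF layout: (out_features, in_features)

    # Before: transpose on every forward pass, then a separate residual add.
    out_old = torch.mm(attn_output, o_proj_weight.transpose(0, 1)) + residual

    # After: transpose once when the module is built, then fuse matmul and add.
    w_t = o_proj_weight.transpose(0, 1).contiguous()
    out_new = torch.addmm(residual, attn_output, w_t)

    assert torch.allclose(out_old, out_new, atol=1e-5)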

diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 569c5f05a..6b108cd4d 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -2,8 +2,10 @@
 from typing import List, Optional, Tuple
 
 import torch
+from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
+    LlamaConfig,
     LlamaDecoderLayer,
     LlamaForCausalLM,
     LlamaMLP,
@@ -39,6 +41,14 @@ def llama_causal_lm_forward(
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
+    """This function will replace the forward function of LlamaForCausalLM.
+
+    Args:
+        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
+        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
+        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+    """
+
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
@@ -46,7 +56,7 @@ def llama_causal_lm_forward(
         k_caches=k_caches,
         v_caches=v_caches,
     )
-    logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
+    logits = torch.mm(hidden_states, self.lm_head.weight)
     return logits
 
 
@@ -57,6 +67,13 @@ def llama_model_forward(
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
+    """This function will replace the forward function of LlamaModel.
+
+    Args:
+        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
+        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
+        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+    """
     input_ids = batch.get_1D_inputs()
     block_tables = batch.get_block_table_tensor()
 
@@ -74,7 +91,7 @@ def llama_model_forward(
         )
     else:
         output_tensor = torch.zeros(
-            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+            (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
         )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
@@ -116,12 +133,30 @@ def llama_decoder_layer_forward(
     output_tensor: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    """This function will replace the forward function of LlamaDecoderLayer.
+
+    Args:
+        hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
+        block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+            storing mapping of token_position_id -> block_id. Defaults to None.
+        k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
+        v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
+        is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+        sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
+        kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+        cos_sin (Tuple[torch.Tensor], optional): Holding the cos and sin values used in rotary embedding. Defaults to None.
+        fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
+            storing intermediate values in flash-decoding. Defaults to None.
+        output_tensor (torch.Tensor, optional): Intermediate tensor used to hold the attention output. Defaults to None.
+        sm_scale (int, optional): Used for flash attention. Defaults to None.
+    """
     residual = hidden_states
 
     hidden_states = self.input_layernorm(hidden_states)
     # Self Attention
     hidden_states = self.self_attn(
         hidden_states=hidden_states,
+        residual=residual,
         block_tables=block_tables,
         k_cache=k_cache,
         v_cache=v_cache,
@@ -134,88 +169,213 @@ def llama_decoder_layer_forward(
         sm_scale=sm_scale,
     )
 
-    hidden_states = residual + hidden_states
-
     # Fully Connected
     residual = hidden_states
     hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    hidden_states = residual + hidden_states
+    hidden_states = self.mlp(hidden_states, residual)
 
     return hidden_states
 
 
-# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
-@torch.no_grad()
-def llama_attn_forward(
-    self: LlamaAttention,
-    hidden_states: torch.Tensor,
-    block_tables: torch.Tensor = None,
-    k_cache: torch.Tensor = None,
-    v_cache: torch.Tensor = None,
-    is_prompts: bool = True,
-    sequence_lengths: torch.Tensor = None,
-    kv_seq_len: int = 0,
-    cos_sin: Tuple[torch.Tensor] = None,
-    fd_inter_tensor: FDIntermTensors = None,
-    output_tensor: torch.Tensor = None,
-    sm_scale: int = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
-    key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
-        -1, self.num_key_value_heads, self.head_dim
-    )
-    value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
-        -1, self.num_key_value_heads, self.head_dim
-    )
+class NopadLlamaAttention(LlamaAttention):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: Optional[int] = None,
+        attn_qproj_w: torch.Tensor = None,
+        attn_kproj_w: torch.Tensor = None,
+        attn_vproj_w: torch.Tensor = None,
+        attn_oproj_w: torch.Tensor = None,
+    ):
+        """This layer will replace the LlamaAttention.
 
-    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+        Args:
+            config (LlamaConfig): Holding the Llama model config.
+            layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
+            attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
+            attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
+            attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
+            attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
+        """
+        super().__init__(config, layer_idx)
+        self.q_proj.weight = Parameter(attn_qproj_w, requires_grad=False)
+        self.k_proj.weight = Parameter(attn_kproj_w, requires_grad=False)
+        self.v_proj.weight = Parameter(attn_vproj_w, requires_grad=False)
+        self.o_proj.weight = Parameter(attn_oproj_w, requires_grad=False)
+        if self.num_heads == self.num_key_value_heads:
+            qkv_weight_list = [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]
+            self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
+            self.q_proj = None
+            self.k_proj = None
+            self.v_proj = None
 
-    block_size = k_cache.size(-2)
+    @staticmethod
+    def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
+        """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.
 
-    if is_prompts:
-        attn_output = context_attention_unpadded(
-            q=query_states,
-            k=key_states,
-            v=value_states,
-            k_cache=k_cache,
-            v_cache=v_cache,
-            context_lengths=sequence_lengths,
-            block_tables=block_tables,
-            block_size=block_size,
-            output=output_tensor,
-            max_seq_len=kv_seq_len,
-            sm_scale=sm_scale,
+        Args:
+            module (LlamaAttention): The original LlamaAttention layer.
+        """
+        config = module.config
+        layer_idx = module.layer_idx
+
+        attn_qproj_w = module.q_proj.weight.transpose(0, 1)
+        attn_kproj_w = module.k_proj.weight.transpose(0, 1)
+        attn_vproj_w = module.v_proj.weight.transpose(0, 1)
+        attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+
+        attn_layer = NopadLlamaAttention(
+            config=config,
+            layer_idx=layer_idx,
+            attn_qproj_w=attn_qproj_w,
+            attn_kproj_w=attn_kproj_w,
+            attn_vproj_w=attn_vproj_w,
+            attn_oproj_w=attn_oproj_w,
         )
-    else:
-        copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-        copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-        attn_output = flash_decoding_attention(
-            q=query_states,
-            k_cache=k_cache,
-            v_cache=v_cache,
-            kv_seq_len=sequence_lengths,
-            block_tables=block_tables,
-            block_size=block_size,
-            max_seq_len_in_batch=kv_seq_len,
-            output=output_tensor,
-            mid_output=fd_inter_tensor.mid_output,
-            mid_output_lse=fd_inter_tensor.mid_output_lse,
-            sm_scale=sm_scale,
+
+        return attn_layer
+
+    # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
+    @torch.no_grad()
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        block_tables: torch.Tensor = None,
+        k_cache: torch.Tensor = None,
+        v_cache: torch.Tensor = None,
+        is_prompts: bool = True,
+        sequence_lengths: torch.Tensor = None,
+        kv_seq_len: int = 0,
+        cos_sin: Tuple[torch.Tensor] = None,
+        fd_inter_tensor: FDIntermTensors = None,
+        output_tensor: torch.Tensor = None,
+        sm_scale: int = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Args:
+            hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`
+            residual (torch.Tensor): shape `(token_num, embed_dim)`, the residual that is added to the attention output inside o_proj.
+            block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+                storing mapping of token_position_id -> block_id. Defaults to None.
+            k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
+            v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
+            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
+            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+            cos_sin (Tuple[torch.Tensor], optional): Holding the cos and sin values used in rotary embedding. Defaults to None.
+            fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
+                storing intermediate values in flash-decoding. Defaults to None.
+            output_tensor (torch.Tensor, optional): Intermediate tensor used to hold the attention output. Defaults to None.
+            sm_scale (int, optional): Used for flash attention. Defaults to None.
+        """
+
+        if self.num_heads != self.num_key_value_heads:
+            query_states = torch.mm(hidden_states, self.q_proj.weight).view(-1, self.num_heads, self.head_dim)
+            key_states = torch.mm(hidden_states, self.k_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
+            value_states = torch.mm(hidden_states, self.v_proj.weight).view(-1, self.num_key_value_heads, self.head_dim)
+        else:
+            # fused qkv
+            token_nums = hidden_states.size(0)
+            hidden_states = hidden_states.expand(3, -1, -1)
+            query_states, key_states, value_states = (
+                torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
+            )
+
+        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+
+        block_size = k_cache.size(-2)
+
+        if is_prompts:
+            attn_output = context_attention_unpadded(
+                q=query_states,
+                k=key_states,
+                v=value_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                context_lengths=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                output=output_tensor,
+                max_seq_len=kv_seq_len,
+                sm_scale=sm_scale,
+            )
+        else:
+            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            attn_output = flash_decoding_attention(
+                q=query_states,
+                k_cache=k_cache,
+                v_cache=v_cache,
+                kv_seq_len=sequence_lengths,
+                block_tables=block_tables,
+                block_size=block_size,
+                max_seq_len_in_batch=kv_seq_len,
+                output=output_tensor,
+                mid_output=fd_inter_tensor.mid_output,
+                mid_output_lse=fd_inter_tensor.mid_output_lse,
+                sm_scale=sm_scale,
+            )
+
+        attn_output = attn_output.reshape(-1, self.hidden_size)
+        attn_output = torch.addmm(residual, attn_output, self.o_proj.weight)
+
+        return attn_output
+
+
+# NOTE: This may cause the results to differ from the transformers implementation in some cases.
+class NopadLlamaMLP(LlamaMLP):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        mlp_gproj_w: torch.Tensor = None,
+        mlp_uproj_w: torch.Tensor = None,
+        mlp_dproj_w: torch.Tensor = None,
+    ):
+        """This layer will replace the LlamaAttention.
+
+        Args:
+            config (LlamaConfig): Holding the Llama model config.
+            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
+            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
+            mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
+        """
+        super().__init__(config)
+        self.gate_proj.weight = Parameter(mlp_gproj_w, requires_grad=False)
+        self.up_proj.weight = Parameter(mlp_uproj_w, requires_grad=False)
+        self.down_proj.weight = Parameter(mlp_dproj_w, requires_grad=False)
+
+    @staticmethod
+    def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
+        """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
+
+        Args:
+            module (LlamaMLP): The original LlamaMLP layer.
+        """
+        config = module.config
+
+        mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
+        mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
+        mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
+
+        mlp_layer = NopadLlamaMLP(
+            config=config,
+            mlp_gproj_w=mlp_gproj_w,
+            mlp_uproj_w=mlp_uproj_w,
+            mlp_dproj_w=mlp_dproj_w,
         )
-        attn_output = attn_output.squeeze(1)
 
-    attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
-    attn_output = attn_output.reshape(-1, self.hidden_size)
-    attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))
+        return mlp_layer
 
-    return attn_output
-
-
-@torch.no_grad()
-def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
-    gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
-    act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
-    up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
-    tmp_out = act_out * up_proj_out
-    return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
+    @torch.no_grad()
+    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (torch.Tensor): input to the layer of shape `(token_num, embed_dim)`.
+            residual (torch.Tensor): shape `(token_num, embed_dim)`, the residual that is added to the MLP output inside down_proj.
+        """
+        gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight)
+        act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
+        up_proj_out = torch.mm(hidden_states, self.up_proj.weight)
+        tmp_out = act_out * up_proj_out
+        return torch.addmm(residual, tmp_out, self.down_proj.weight)
diff --git a/colossalai/inference/modeling/models/padding_llama.py b/colossalai/inference/modeling/models/padding_llama.py
index 63a8d3673..51d718a53 100644
--- a/colossalai/inference/modeling/models/padding_llama.py
+++ b/colossalai/inference/modeling/models/padding_llama.py
@@ -2,7 +2,13 @@
 from typing import List, Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaModel,
+)
 
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.layers.attention import PagedAttention
@@ -53,6 +59,14 @@ def llama_causal_lm_forward(
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
+    """This function will replace the forward function of LlamaForCausalLM.
+
+    Args:
+        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
+        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
+        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+    """
+
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
         self.model,
@@ -71,6 +85,13 @@ def llama_model_forward(
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
 ):
+    """This function will replace the forward function of LlamaModel.
+
+    Args:
+        batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None.
+        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
+        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+    """
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
     attention_mask = batch.get_attn_mask()
@@ -110,7 +131,7 @@ def llama_model_forward(
         )
     else:
         output_tensor = torch.zeros(
-            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+            (batch_size, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
         )
     sm_scale = 1.0 / (batch.head_dim**0.5)
 
@@ -131,7 +152,8 @@ def llama_model_forward(
             sm_scale=sm_scale,
         )
 
-    hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
+    if batch.is_prompts:
+        hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
     hidden_states = self.norm(hidden_states)
 
     return hidden_states
@@ -154,6 +176,23 @@ def llama_decoder_layer_forward(
     output_tensor: torch.Tensor = None,
     sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    """This function will replace the forward function of LlamaDecoderLayer.
+
+    Args:
+        hidden_states (torch.Tensor): input to the layer of shape `(batch_size, seq_len, embed_dim)`.
+        position_ids (torch.LongTensor): The position ids of input sequences.
+        block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+            storing mapping of token_position_id -> block_id. Defaults to None.
+        k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
+        v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
+        is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+        sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
+        kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+        cos_sin (Tuple[torch.Tensor], optional): Holding the cos and sin values used in rotary embedding. Defaults to None.
+        fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for storing intermediate values in flash-decoding. Defaults to None.
+        output_tensor (torch.Tensor, optional): Intermediate tensor used to hold the attention output. Defaults to None.
+        sm_scale (int, optional): Used for flash attention. Defaults to None.
+    """
     residual = hidden_states
 
     hidden_states = self.input_layernorm(hidden_states)
@@ -185,108 +224,192 @@ def llama_decoder_layer_forward(
     return hidden_states
 
 
-# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
-@torch.no_grad()
-def llama_attn_forward(
-    self: LlamaAttention,
-    hidden_states: torch.Tensor,
-    position_ids: torch.LongTensor,
-    block_tables: torch.Tensor = None,
-    k_cache: torch.Tensor = None,
-    v_cache: torch.Tensor = None,
-    is_prompts: bool = True,
-    sequence_lengths: torch.Tensor = None,
-    attention_mask: torch.Tensor = None,
-    kv_seq_len: int = 0,
-    cos_sin: Tuple[torch.Tensor] = None,
-    fd_inter_tensor: FDIntermTensors = None,
-    output_tensor: torch.Tensor = None,
-    sm_scale: int = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+class PadLlamaAttention(LlamaAttention):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: Optional[int] = None,
+        attn_qproj_w: torch.nn.Parameter = None,
+        attn_kproj_w: torch.nn.Parameter = None,
+        attn_vproj_w: torch.nn.Parameter = None,
+        attn_oproj_w: torch.nn.Parameter = None,
+    ):
+        """This layer will replace the LlamaAttention.
 
-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        Args:
+            config (LlamaConfig): Holding the Llama model config.
+            layer_idx (Optional[int], optional): The decode layer id of this attention layer. Defaults to None.
+            attn_qproj_w (torch.nn.Parameter, optional): The q_proj weight. Defaults to None.
+            attn_kproj_w (torch.nn.Parameter, optional): The k_proj weight. Defaults to None.
+            attn_vproj_w (torch.nn.Parameter, optional): The v_proj weight. Defaults to None.
+            attn_oproj_w (torch.nn.Parameter, optional): The o_proj weight. Defaults to None.
+        """
+        super().__init__(config, layer_idx)
+        self.q_proj.weight = attn_qproj_w
+        self.k_proj.weight = attn_kproj_w
+        self.v_proj.weight = attn_vproj_w
+        self.o_proj.weight = attn_oproj_w
 
-    if HAS_TRITON:
-        if is_prompts:
-            if attention_mask is not None:
-                query_states, key_states, value_states, indices = unpading_input(
-                    query_states, key_states, value_states, attention_mask
+    @staticmethod
+    def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
+        """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention
+
+        Args:
+            module (LlamaAttention): The original LlamaAttention layer.
+        """
+        config = module.config
+        layer_idx = module.layer_idx
+
+        attn_qproj_w = module.q_proj.weight
+        attn_kproj_w = module.k_proj.weight
+        attn_vproj_w = module.v_proj.weight
+        attn_oproj_w = module.o_proj.weight
+
+        attn_layer = PadLlamaAttention(
+            config=config,
+            layer_idx=layer_idx,
+            attn_qproj_w=attn_qproj_w,
+            attn_kproj_w=attn_kproj_w,
+            attn_vproj_w=attn_vproj_w,
+            attn_oproj_w=attn_oproj_w,
+        )
+
+        return attn_layer
+
+    @torch.no_grad()
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+        block_tables: torch.Tensor = None,
+        k_cache: torch.Tensor = None,
+        v_cache: torch.Tensor = None,
+        is_prompts: bool = True,
+        sequence_lengths: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        kv_seq_len: int = 0,
+        cos_sin: Tuple[torch.Tensor] = None,
+        fd_inter_tensor: FDIntermTensors = None,
+        output_tensor: torch.Tensor = None,
+        sm_scale: int = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Args:
+            hidden_states (torch.Tensor): input to the layer of shape `(batch_size, seq_len, embed_dim)`
+            position_ids (torch.LongTensor): The position ids of input sequences.
+            block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence],
+                storing mapping of token_position_id -> block_id. Defaults to None.
+            k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None.
+            v_cache (torch.Tensor, optional): It holds the GPU memory for the value cache. Defaults to None.
+            is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True.
+            sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None.
+            attention_mask (torch.Tensor, optional): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)`
+                where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens.
+            kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0.
+            cos_sin (Tuple[torch.Tensor], optional): Holding the cos and sin values used in rotary embedding. Defaults to None.
+            fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for
+                storing intermediate values in flash-decoding. Defaults to None.
+            output_tensor (torch.Tensor, optional): Intermediate tensor used to hold the attention output. Defaults to None.
+            sm_scale (int, optional): Used for flash attention. Defaults to None.
+        """
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+
+        if HAS_TRITON:
+            if is_prompts:
+                if attention_mask is not None:
+                    query_states, key_states, value_states, indices = unpading_input(
+                        query_states, key_states, value_states, attention_mask
+                    )
+                else:
+                    query_states = query_states.view(-1, self.num_heads, self.head_dim)
+                    key_states = key_states.view(-1, self.num_heads, self.head_dim)
+                    value_states = value_states.view(-1, self.num_heads, self.head_dim)
+            else:
+                query_states = query_states.squeeze(dim=1)
+                key_states = key_states.squeeze(dim=1)
+                value_states = value_states.squeeze(dim=1)
+
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+
+            block_size = k_cache.size(-2)
+
+            if is_prompts:
+                attn_output = context_attention_unpadded(
+                    q=query_states,
+                    k=key_states,
+                    v=value_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    context_lengths=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    output=output_tensor,
+                    max_seq_len=kv_seq_len,
+                    sm_scale=sm_scale,
+                )
+                if attention_mask is not None:
+                    attn_output = pad_input(attn_output, indices, bsz, q_len)
+            else:
+                copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+                copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+                attn_output = flash_decoding_attention(
+                    q=query_states,
+                    k_cache=k_cache,
+                    v_cache=v_cache,
+                    kv_seq_len=sequence_lengths,
+                    block_tables=block_tables,
+                    block_size=block_size,
+                    max_seq_len_in_batch=kv_seq_len,
+                    output=output_tensor,
+                    mid_output=fd_inter_tensor.mid_output,
+                    mid_output_lse=fd_inter_tensor.mid_output_lse,
+                    sm_scale=sm_scale,
+                )
+                attn_output = attn_output.squeeze(1)
+        else:
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+            query_states = query_states.transpose(1, 2)
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+
+            if is_prompts:
+                attn_output = PagedAttention.pad_context_forward(
+                    query_states,
+                    key_states,
+                    value_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    attention_mask,
                 )
             else:
-                query_states = query_states.view(-1, self.num_heads, self.head_dim)
-                key_states = key_states.view(-1, self.num_heads, self.head_dim)
-                value_states = value_states.view(-1, self.num_heads, self.head_dim)
-        else:
-            query_states = query_states.squeeze(dim=1)
-            key_states = key_states.squeeze(dim=1)
-            value_states = value_states.squeeze(dim=1)
+                attn_output = PagedAttention.pad_decoding_forward(
+                    query_states,
+                    key_states,
+                    value_states,
+                    k_cache,
+                    v_cache,
+                    sequence_lengths,
+                    block_tables,
+                    attention_mask,
+                )
 
-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
+        attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
 
-        block_size = k_cache.size(-2)
-
-        if is_prompts:
-            attn_output = context_attention_unpadded(
-                q=query_states,
-                k=key_states,
-                v=value_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                context_lengths=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                output=output_tensor,
-                max_seq_len=kv_seq_len,
-                sm_scale=sm_scale,
-            )
-            if attention_mask is not None:
-                attn_output = pad_input(attn_output, indices, bsz, q_len)
-        else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            attn_output = flash_decoding_attention(
-                q=query_states,
-                k_cache=k_cache,
-                v_cache=v_cache,
-                kv_seq_len=sequence_lengths,
-                block_tables=block_tables,
-                block_size=block_size,
-                max_seq_len_in_batch=kv_seq_len,
-                output=output_tensor,
-                mid_output=fd_inter_tensor.mid_output,
-                mid_output_lse=fd_inter_tensor.mid_output_lse,
-                sm_scale=sm_scale,
-            )
-            attn_output = attn_output.squeeze(1)
-    else:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        if is_prompts:
-            attn_output = PagedAttention.pad_context_forward(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
-            )
-        else:
-            attn_output = PagedAttention.pad_decoding_forward(
-                query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
-            )
-
-    attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim)
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    attn_output = self.o_proj(attn_output)
-
-    return attn_output
+        return attn_output
 
 
 @torch.no_grad()
diff --git a/colossalai/inference/modeling/policy/nopadding_llama.py b/colossalai/inference/modeling/policy/nopadding_llama.py
index 3eaa59f74..aed72ef73 100644
--- a/colossalai/inference/modeling/policy/nopadding_llama.py
+++ b/colossalai/inference/modeling/policy/nopadding_llama.py
@@ -1,25 +1,18 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaMLP,
-    LlamaModel,
-    LlamaRMSNorm,
-    LlamaSdpaAttention,
-)
+from torch.nn import Parameter
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.modeling.models.nopadding_llama import (
-    llama_attn_forward,
+    NopadLlamaAttention,
+    NopadLlamaMLP,
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
-    nopad_mlp,
 )
 from colossalai.inference.utils import init_to_get_rotary
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 
 # import colossalai
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
@@ -50,6 +43,27 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def module_policy(self):
         policy = super().module_policy()
+
+        decoder_attribute_replacement = {
+            "lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
+        }
+        policy[LlamaForCausalLM] = ModulePolicyDescription(
+            attribute_replacement=decoder_attribute_replacement,
+        )
+
+        policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="mlp",
+                    target_module=NopadLlamaMLP,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="self_attn",
+                    target_module=NopadLlamaAttention,
+                ),
+            ]
+        )
+
         self.shard_config._infer()
 
         infer_forward = llama_causal_lm_forward
@@ -68,28 +82,6 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )
 
-        infer_forward = nopad_mlp
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaAttention
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
-        )
-
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()
diff --git a/colossalai/inference/modeling/policy/padding_llama.py b/colossalai/inference/modeling/policy/padding_llama.py
index 0c83189f8..9aa64f55b 100644
--- a/colossalai/inference/modeling/policy/padding_llama.py
+++ b/colossalai/inference/modeling/policy/padding_llama.py
@@ -1,18 +1,10 @@
 from functools import partial
 
 import torch
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaFlashAttention2,
-    LlamaForCausalLM,
-    LlamaModel,
-    LlamaRMSNorm,
-    LlamaSdpaAttention,
-)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.modeling.models.padding_llama import (
-    llama_attn_forward,
+    PadLlamaAttention,
     llama_causal_lm_forward,
     llama_decoder_layer_forward,
     llama_model_forward,
@@ -49,105 +41,16 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
 
     def module_policy(self):
         policy = super().module_policy()
-        decoder_attribute_replacement = {
-            "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-            "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
-            // self.shard_config.tensor_parallel_size,
-        }
-        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
-            from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
 
-            policy[LlamaDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.q_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.k_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.v_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.o_proj",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.gate_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.up_proj",
-                        target_module=ColCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.down_proj",
-                        target_module=RowCaiQuantLinear,
-                        kwargs={"split_num": 1},
-                    ),
-                ],
-            )
+        policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="self_attn",
+                    target_module=PadLlamaAttention,
+                ),
+            ]
+        )
 
-        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
-            from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
-            from colossalai.inference.quant.smoothquant.models.parallel_linear import (
-                ColW8A8BFP32OFP32Linear,
-                RowW8A8B8O8Linear,
-                RowW8A8BFP32O32LinearSiLU,
-                RowW8A8BFP32OFP32Linear,
-            )
-
-            policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
-                attribute_replacement=decoder_attribute_replacement,
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.q_proj",
-                        target_module=RowW8A8B8O8Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.k_proj",
-                        target_module=RowW8A8B8O8Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.v_proj",
-                        target_module=RowW8A8B8O8Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="self_attn.o_proj",
-                        target_module=ColW8A8BFP32OFP32Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.gate_proj",
-                        target_module=RowW8A8BFP32O32LinearSiLU,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.up_proj",
-                        target_module=RowW8A8BFP32OFP32Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.down_proj",
-                        target_module=ColW8A8BFP32OFP32Linear,
-                        kwargs={"split_num": 1},
-                    ),
-                ],
-            )
         self.shard_config._infer()
 
         infer_forward = llama_causal_lm_forward
@@ -166,24 +69,6 @@ class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
             description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
         )
 
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaAttention
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
-        )
-
-        infer_forward = llama_attn_forward
-        method_replacement = {"forward": partial(infer_forward)}
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
-        )
-
         infer_forward = None
         if HAS_TRITON_RMSNORM:
             infer_forward = get_triton_rmsnorm_forward()
diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
index 4bba24503..37fcd504c 100644
--- a/colossalai/kernel/triton/flash_decoding.py
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel(
     stride_o_lset,
     stride_o_lseh,
     stride_o_lseb,
-    stride_ob,
-    stride_ol,
+    stride_ot,
     stride_oh,
     stride_od,
     BLOCK_KV: tl.constexpr,
@@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
         m_i = m_ij
 
     acc = acc / l
-    offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel
+    offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
     tl.store(O + offsets_O, acc.to(O.type.element_ty))
     return
 
@@ -212,7 +211,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
-        output (torch.Tensor):  [bsz, 1, num_heads, head_dim]
+        output (torch.Tensor):  [bsz, num_heads, head_dim]
         mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -294,7 +293,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )
 
-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
 
     grid = (triton.next_power_of_2(bsz), num_heads)
 
@@ -314,7 +313,6 @@ def flash_decoding_attention(
         output.stride(0),
         output.stride(1),
         output.stride(2),
-        output.stride(3),
         BLOCK_KV=block_size,
         HEAD_DIM=head_dim,
     )
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
index bdd79836e..6870ed384 100755
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -25,10 +25,20 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 # benchmark llama2-7b one single GPU
 
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_512_256.txt
 done
 
 
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_256.txt
+done
+
+
+for bsz in 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256_128.txt
+done
+
+
+for bsz in 16 32 64; do
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024_128.txt
 done
diff --git a/tests/test_infer_ops/triton/kernel_utils.py b/tests/test_infer_ops/triton/kernel_utils.py
index 7c3bc5ca6..22167ded0 100644
--- a/tests/test_infer_ops/triton/kernel_utils.py
+++ b/tests/test_infer_ops/triton/kernel_utils.py
@@ -69,6 +69,7 @@ def torch_attn_ref(
             f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
         )
     out = out.transpose(1, 2).contiguous()
+    out = out.squeeze(1)
     return out
 
 
diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py
index a49ee3146..5eac026bb 100644
--- a/tests/test_infer_ops/triton/test_decoding_attn.py
+++ b/tests/test_infer_ops/triton/test_decoding_attn.py
@@ -94,7 +94,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -189,7 +189,7 @@ def bench_kernel(
         block_tables = block_tables.to(device=device)
         # the maximum block length splitted on kv should be the kv cache block size
         kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-        output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
+        output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
         mid_output = torch.empty(
             size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
         )
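
Note on the decode-phase output layout: flash_decoding_attention now writes to
a (bsz, num_heads, head_dim) tensor instead of (bsz, 1, num_heads, head_dim),
which is why the squeeze(1) disappears from the attention forward and moves
into the torch reference in kernel_utils.py. Shape-only illustration
(hypothetical sizes, not from this patch):

    import torch

    bsz, num_heads, head_dim = 4, 8, 32
    old_output = torch.empty(bsz, 1, num_heads, head_dim)  # previous layout
    new_output = old_output.squeeze(1)                     # layout after this patch
    assert new_output.shape == (bsz, num_heads, head_dim)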