From 19e1a5cf16ead982eb8818cd69e41b06a5d23b20 Mon Sep 17 00:00:00 2001
From: Hongxin Liu <lhx0217@gmail.com>
Date: Wed, 27 Mar 2024 11:19:32 +0800
Subject: [PATCH] [shardformer] update colo attention to support custom mask
 (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
---
 .github/workflows/build_on_pr.yml             |   4 +-
 .github/workflows/build_on_schedule.yml       |   2 +-
 .../compatiblity_test_on_dispatch.yml         |   2 +-
 .github/workflows/compatiblity_test_on_pr.yml |   2 +-
 .../compatiblity_test_on_schedule.yml         |   2 +-
 colossalai/kernel/kernel_loader.py            |  24 +-
 colossalai/nn/layer/colo_attention.py         | 209 --------
 colossalai/shardformer/layer/__init__.py      |   3 +
 colossalai/shardformer/layer/attn.py          | 269 +++++++++++
 colossalai/shardformer/modeling/blip2.py      |  39 +-
 colossalai/shardformer/modeling/chatglm2.py   | 125 ++---
 colossalai/shardformer/modeling/gpt2.py       | 448 +++++++++++++-----
 colossalai/shardformer/modeling/gptj.py       | 363 ++++++++++----
 colossalai/shardformer/modeling/llama.py      | 197 ++++++--
 colossalai/shardformer/modeling/opt.py        | 335 +++++++++----
 colossalai/shardformer/modeling/vit.py        |  35 +-
 colossalai/shardformer/modeling/whisper.py    | 302 +++++++++---
 colossalai/shardformer/policies/gpt2.py       |  55 ++-
 colossalai/shardformer/policies/gptj.py       |  51 +-
 colossalai/shardformer/policies/llama.py      |  10 +
 colossalai/shardformer/policies/opt.py        |  58 ++-
 colossalai/shardformer/policies/whisper.py    |  24 +-
 colossalai/testing/comparison.py              |  30 +-
 extensions/README.md                          |   4 +-
 extensions/__init__.py                        |  10 +-
 extensions/base_extension.py                  |   4 +-
 extensions/cpu_adam/cpu_adam_arm.py           |   4 +-
 extensions/cpu_adam/cpu_adam_x86.py           |   8 +-
 extensions/cuda_extension.py                  |   4 +-
 extensions/flash_attention/__init__.py        |  12 +-
 .../flash_attention_dao_cuda.py               |  99 ++--
 .../flash_attention/flash_attention_npu.py    |  63 +--
 .../flash_attention_sdpa_cuda.py              |  56 +++
 .../flash_attention_xformers_cuda.py          |  94 ----
 setup.py                                      |   4 +-
 .../test_shardformer/test_flash_attention.py  | 147 ++++++
 tests/test_shardformer/test_model/_utils.py   |  23 +-
 .../test_model/test_shard_blip2.py            |  51 +-
 .../test_model/test_shard_chatglm2.py         |  69 ++-
 .../test_model/test_shard_gpt2.py             |  77 ++-
 .../test_model/test_shard_gptj.py             |  78 ++-
 .../test_model/test_shard_llama.py            |   4 +-
 .../test_model/test_shard_opt.py              |  90 +++-
 .../test_model/test_shard_t5.py               |  56 ++-
 tests/test_utils/test_flash_attention.py      | 167 -------
 45 files changed, 2543 insertions(+), 1170 deletions(-)
 delete mode 100644 colossalai/nn/layer/colo_attention.py
 create mode 100644 colossalai/shardformer/layer/attn.py
 create mode 100644 extensions/flash_attention/flash_attention_sdpa_cuda.py
 delete mode 100644 extensions/flash_attention/flash_attention_xformers_cuda.py
 create mode 100644 tests/test_shardformer/test_flash_attention.py
 delete mode 100644 tests/test_utils/test_flash_attention.py

diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index b01d15490..5bdadca78 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -117,7 +117,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
 
       - name: Store TensorNVMe Cache
         run: |
@@ -201,4 +201,4 @@ jobs:
         uses: actions/upload-artifact@v3
         with:
           name: report
-          path: report/
\ No newline at end of file
+          path: report/
diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml
index 3ff19b37b..e560d0c00 100644
--- a/.github/workflows/build_on_schedule.yml
+++ b/.github/workflows/build_on_schedule.yml
@@ -44,7 +44,7 @@ jobs:
           cd TensorNVMe
           conda install cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
 
       - uses: actions/checkout@v2
         if: steps.check-avai.outputs.avai == 'true'
diff --git a/.github/workflows/compatiblity_test_on_dispatch.yml b/.github/workflows/compatiblity_test_on_dispatch.yml
index 764938806..95a94c27b 100644
--- a/.github/workflows/compatiblity_test_on_dispatch.yml
+++ b/.github/workflows/compatiblity_test_on_dispatch.yml
@@ -66,7 +66,7 @@ jobs:
           cd TensorNVMe
           apt update && apt install -y cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index f582b3090..aef4816ef 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -60,7 +60,7 @@ jobs:
           cd TensorNVMe
           apt update && apt install -y cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
diff --git a/.github/workflows/compatiblity_test_on_schedule.yml b/.github/workflows/compatiblity_test_on_schedule.yml
index 3348b51ec..3dc8a5a32 100644
--- a/.github/workflows/compatiblity_test_on_schedule.yml
+++ b/.github/workflows/compatiblity_test_on_schedule.yml
@@ -56,7 +56,7 @@ jobs:
           cd TensorNVMe
           apt update && apt install -y cmake
           pip install -r requirements.txt
-          pip install -v .
+          DISABLE_URING=1 pip install -v .
       - uses: actions/checkout@v2
         with:
           ssh-key: ${{ secrets.SSH_KEY_FOR_CI }}
diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py
index 148c3e3fc..353e29b3d 100644
--- a/colossalai/kernel/kernel_loader.py
+++ b/colossalai/kernel/kernel_loader.py
@@ -6,7 +6,7 @@ from .extensions import (
     CpuAdamX86Extension,
     FlashAttentionDaoCudaExtension,
     FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FusedOptimizerCudaExtension,
     LayerNormCudaExtension,
     MoeCudaExtension,
@@ -65,9 +65,9 @@ class KernelLoader:
         else:
             usable_exts = []
             for ext in exts:
-                if ext.is_hardware_available():
+                if ext.is_available():
                     # make sure the machine is compatible during kernel loading
-                    ext.assert_hardware_compatible()
+                    ext.assert_compatible()
                     usable_exts.append(ext)
 
         assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine."
@@ -106,4 +106,20 @@ class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader):
 
 
 class FlashAttentionLoader(KernelLoader):
-    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension]
+    REGISTRY = [
+        FlashAttentionNpuExtension,
+        FlashAttentionDaoCudaExtension,
+        FlashAttentionSdpaCudaExtension,
+    ]
+
+
+class FlashAttentionWithPaddingMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension]
+
+
+class FlashAttentionWithCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
+
+
+class FlashAttentionForFloatAndCustomMaskLoader(KernelLoader):
+    REGISTRY = [FlashAttentionSdpaCudaExtension]
diff --git a/colossalai/nn/layer/colo_attention.py b/colossalai/nn/layer/colo_attention.py
deleted file mode 100644
index 0b7011e8e..000000000
--- a/colossalai/nn/layer/colo_attention.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import enum
-import math
-import warnings
-from dataclasses import dataclass
-from typing import Iterable, Optional, Tuple
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-
-from colossalai.accelerator import get_accelerator
-from colossalai.kernel.kernel_loader import FlashAttentionLoader
-
-
-@dataclass
-class SeqLenInfo:
-    seqlens: Iterable[int] = None
-    indices: torch.Tensor = None
-    max_seqlen: int = None
-    cu_seqlens: torch.Tensor = None
-
-    @staticmethod
-    def materialize(
-        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
-    ):
-        if attn_mask is not None:
-            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
-            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
-        else:
-            batch_size, tgt_len = size[0], size[1]
-            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
-            seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
-        max_seqlen = max(seqlens)
-        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
-        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
-
-
-class AttnMaskType(enum.Enum):
-    padding = 1
-    causal = 2
-    paddedcausal = 3
-
-
-class Unpad(torch.autograd.Function):
-    """
-    Adapted from
-    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
-    """
-
-    @staticmethod
-    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
-        ctx.save_for_backward(indices)
-        # [b, s, ...]
-        assert tensor.ndim >= 3
-        ctx.bsz = tensor.shape[0]
-        out = rearrange(tensor, "b s ... -> (b s) ...")
-        ctx.shape = out.shape
-        # [ntokens, ...]
-        return out[indices]
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        # [ntokens, ...]
-        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
-        grad[indices] = grad_output
-        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
-        # [b, s, ...]
-        return grad, None
-
-
-class Repad(torch.autograd.Function):
-    """
-    Adapted from
-    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
-    """
-
-    @staticmethod
-    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
-        ctx.save_for_backward(indices)
-        # [ntokens, ...]
-        tensor = tensor
-        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
-        # [b*s, ...]
-        out[indices] = tensor
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (indices,) = ctx.saved_tensors
-        # [b*s, ...]
-        grad = grad_output[indices]
-        # [ntokens, ...]
-        return grad, None, None, None
-
-
-class ColoAttention(torch.nn.Module):
-    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
-        super().__init__()
-        assert (
-            embed_dim % num_heads == 0
-        ), f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
-        if scale is not None:
-            self.scale = scale
-        else:
-            self.scale = 1 / math.sqrt(embed_dim // num_heads)
-        self.dropout = dropout
-
-        self.attn = FlashAttentionLoader().load()
-
-    @staticmethod
-    def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
-        return Unpad.apply(tensor, indices)
-
-    @staticmethod
-    def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
-        return Repad.apply(tensor, indices, batch_size, seq_len)
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-        origin_attn_mask: Optional[torch.Tensor] = None,
-        attn_mask_type: Optional[AttnMaskType] = None,
-        bias: Optional[torch.Tensor] = None,
-    ):
-        """
-        ColoAttention
-
-        Args:
-            q: (batch, q_seqlen, nheads, headdim)
-            k: (batch, kv_seqlen, nheads, headdim)
-            v: (batch, kv_seqlen, nheads, headdim)
-            origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
-            bias: will not be used
-        Return:
-            attn_out: (batch, q_seqlen, nheads, headdim).
-        """
-        # if flash attention is not applicable, switch to memory effcient attention
-        if self.attn.__name__ == "flash_attention" and (
-            query.dtype not in [torch.float16, torch.bfloat16] or bias != None
-        ):
-            warnings.warn(
-                f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
-            )
-            self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
-
-        padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
-        causal = attn_mask_type is not None and attn_mask_type.value > 1
-
-        batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
-        # unpad
-        seq_len_info_q = None
-        seq_len_info_kv = None
-        if padded:
-            # bert style, unpad process
-            assert (
-                attn_mask is not None
-            ), f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
-            assert attn_mask.dim() == 2, (
-                "attention mask is supposed to have shape (batch_size, seq_len), "
-                + f"but got {attn_mask.dim()} dimensions."
-            )
-
-            # bert style
-            if tgt_len == src_len:
-                seq_len_info_q = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
-                if batch_size > 1:
-                    query, key, value = self.unpad(
-                        torch.stack([query, key, value], dim=2), seq_len_info_q.indices
-                    ).unbind(dim=1)
-                else:
-                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
-                seq_len_info_kv = seq_len_info_q
-            else:
-                seq_len_info_q = SeqLenInfo.materialize(size=(batch_size, tgt_len), device=query.device)
-                seq_len_info_kv = SeqLenInfo.materialize(attn_mask=attn_mask, device=query.device)
-                if batch_size > 1:
-                    query = rearrange(query, "b s ... -> c (b s) ...", c=1)
-                    key, value = self.unpad(torch.stack([query, key, value], dim=2), seq_len_info_kv.indices).unbind(
-                        dim=1
-                    )
-                else:
-                    query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
-
-        out = self.attn(
-            query,
-            key,
-            value,
-            seq_len_info_q=seq_len_info_q,
-            seq_len_info_kv=seq_len_info_kv,
-            origin_attn_mask=origin_attn_mask,
-            dropout_p=self.dropout,
-            scale=self.scale,
-            causal=causal,
-            padded=padded,
-        )
-
-        # repad
-        if padded:
-            if batch_size > 1:
-                out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
-            out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
-
-        if len(out.shape) == 4:
-            out = rearrange(out, "b s h d -> b s (h d)")
-        return out
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index 56e8b08c4..c9b4317a6 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -1,3 +1,4 @@
+from .attn import AttnMaskType, ColoAttention
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
@@ -23,4 +24,6 @@ __all__ = [
     "FusedRMSNorm",
     "FusedLinear1D_Col",
     "ParallelModule",
+    "AttnMaskType",
+    "ColoAttention",
 ]
diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
new file mode 100644
index 000000000..f3f6e59d3
--- /dev/null
+++ b/colossalai/shardformer/layer/attn.py
@@ -0,0 +1,269 @@
+from enum import Enum
+from typing import Callable, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from colossalai.kernel.kernel_loader import (
+    FlashAttentionForFloatAndCustomMaskLoader,
+    FlashAttentionLoader,
+    FlashAttentionWithCustomMaskLoader,
+    FlashAttentionWithPaddingMaskLoader,
+    KernelLoader,
+)
+
+__all__ = [
+    "AttnMaskType",
+    "ColoAttention",
+]
+
+
+class AttnMaskType(Enum):
+    CUSTOM = 0
+    PADDED = 1
+    CAUSAL = 2
+    PADDED_CAUSAL = 3
+
+
+def invert_mask(mask: torch.Tensor) -> torch.Tensor:
+    """Invert the mask tensor.
+
+    Args:
+        mask (torch.Tensor): Mask tensor. Shape should be [B, 1, Sq, Skv]
+
+    Returns:
+        torch.Tensor: Inverted mask tensor.
+    """
+    inverted_mask = 1.0 - mask
+    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(mask.dtype).min)
+
+
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
+def get_pad_info(padding_mask: torch.Tensor) -> Tuple[int, torch.Tensor, torch.Tensor]:
+    """Get padding information from padding mask.
+
+    Args:
+        padding_mask (torch.Tensor): Padding mask tensor. Shape should be [B, S]
+
+    Returns:
+        Tuple[int, torch.Tensor, torch.Tensor]: Tuple of (max_seq_len, cu_seqlens, indices)
+    """
+    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return max_seqlen_in_batch, cu_seqlens, indices
+
+
+class ColoAttention:
+    _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+
+    @staticmethod
+    def _init_kernels_dispatch():
+        if ColoAttention._kernel_dispatch_map is None:
+            # fp16/bf16
+            half_dispatch_map = {
+                None: FlashAttentionLoader(),
+                AttnMaskType.CUSTOM: FlashAttentionWithCustomMaskLoader(),
+                AttnMaskType.PADDED: FlashAttentionWithPaddingMaskLoader(),
+                AttnMaskType.CAUSAL: FlashAttentionLoader(),
+                AttnMaskType.PADDED_CAUSAL: FlashAttentionWithPaddingMaskLoader(),
+            }
+            # fp32
+            float_dispatch_map = {
+                None: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.CUSTOM: FlashAttentionForFloatAndCustomMaskLoader(),
+                AttnMaskType.CAUSAL: FlashAttentionForFloatAndCustomMaskLoader(),
+            }
+            ColoAttention._kernel_dispatch_map = {
+                torch.float16: half_dispatch_map,
+                torch.bfloat16: half_dispatch_map,
+                torch.float32: float_dispatch_map,
+            }
+
+    @staticmethod
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+        ColoAttention._init_kernels_dispatch()
+        if (
+            dtype not in ColoAttention._kernel_dispatch_map
+            or mask_type not in ColoAttention._kernel_dispatch_map[dtype]
+        ):
+            raise ValueError(
+                "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
+            )
+        # lazy load
+        if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
+            ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
+                mask_type
+            ].load()
+        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+    @staticmethod
+    def prepare_attn_kwargs(
+        shape_4d: Tuple[int],
+        dtype: torch.dtype,
+        device: torch.device,
+        q_padding_mask: Optional[torch.Tensor] = None,
+        kv_padding_mask: Optional[torch.Tensor] = None,
+        is_causal: bool = False,
+    ) -> Dict[str, torch.Tensor]:
+        """Return a dictionary of keyword arguments for attention function. It supports 4 mask type.
+        1. custom mask: no padding mask and is_causal=False, return {}, users should handle attention mask by themselves.
+        2. padded mask: recv padding mask and is_causal=False, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+        3. causal mask: no padding mask and is_causal=True, return {attention_mask, attention_mask_type}.
+        4. padded causal mask: recv padding mask and is_causal=True, return {attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, q_indices, kv_indices}.
+
+        Args:
+            shape_4d (Tuple[int]): Should be (B, 1, Sq, Skv)
+            dtype (torch.dtype): Dtype of attention mask, generally should be ``hidden_states.dtype``
+            device (torch.device): Device of attention mask, generally should be ``hidden_states.device``
+            q_padding_mask (Optional[torch.Tensor], optional): Padding mask of query. It should be a long tensor or int tensor.
+                The shape should be [B, Sq]. ``1`` means valid token, and ``0`` means padding token. Defaults to None.
+            kv_padding_mask (Optional[torch.Tensor], optional): Padding mask of key and value. It should be a long tensor or int tensor.
+                The shape should be [B, Skv]. ``1`` means valid token, and ``0`` means padding token.
+                If it's None and ``q_padding_mask`` is not None, it will be set to ``q_padding_mask``. Defaults to None.
+            is_causal (bool, optional): Whether to use causal attention mask. Defaults to False.
+
+        Returns:
+            Dict[str, torch.Tensor]: Dictionary of keyword arguments for attention function.
+        """
+        if q_padding_mask is None and not is_causal:
+            return {}
+        assert len(shape_4d) == 4 and shape_4d[1] == 1
+        b, _, s_q, s_kv = shape_4d
+        outputs = {}
+        if (q_padding_mask is None or q_padding_mask.bool().all()) and (
+            kv_padding_mask is None or kv_padding_mask.bool().all()
+        ):
+            # no padding
+            assert is_causal
+            outputs["attention_mask_type"] = AttnMaskType.CAUSAL
+            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device).tril(diagonal=0).expand(b, s_q, s_kv)
+        else:
+            if kv_padding_mask is None:
+                # self attention
+                kv_padding_mask = q_padding_mask
+            assert q_padding_mask.shape == (b, s_q) and kv_padding_mask.shape == (
+                b,
+                s_kv,
+            ), f"q_padding_mask shape {q_padding_mask.shape} and kv_padding_mask shape {kv_padding_mask.shape} should be the same. ({shape_4d})"
+            attention_mask = torch.einsum("bi,bj->bij", q_padding_mask, kv_padding_mask).to(dtype=dtype, device=device)
+            max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
+            max_seqlen_kv, cu_seqlens_kv, kv_indices = get_pad_info(kv_padding_mask)
+            outputs.update(
+                {
+                    "cu_seqlens_q": cu_seqlens_q,
+                    "cu_seqlens_kv": cu_seqlens_kv,
+                    "max_seqlen_q": max_seqlen_q,
+                    "max_seqlen_kv": max_seqlen_kv,
+                    "q_indices": q_indices,
+                    "kv_indices": kv_indices,
+                }
+            )
+            if is_causal:
+                outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
+                attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+            else:
+                outputs["attention_mask_type"] = AttnMaskType.PADDED
+        attention_mask = invert_mask(attention_mask).unsqueeze(1)
+        outputs["attention_mask"] = attention_mask
+        return outputs
+
+    @staticmethod
+    def attention(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask_type: AttnMaskType = AttnMaskType.CUSTOM,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_kv: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        max_seqlen_kv: Optional[int] = None,
+        q_indices: Optional[torch.Tensor] = None,
+        kv_indices: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        scale: Optional[float] = None,
+    ) -> torch.Tensor:
+        """Flash Attention function. It supports 4 mask type.
+        1. custom mask: recv attention_mask
+        2. padded mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
+        3. causal mask: recv attention_mask, attention_mask_type
+        4. padded causal mask: recv attention_mask, attention_mask_type, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, indices
+
+        Args:
+            q (torch.Tensor): Query tensor. Shape should be [B, N, Sq, D]
+            k (torch.Tensor): Key tensor. Shape should be [B, N, Skv, D]
+            v (torch.Tensor): Value tensor. Shape should be [B, N, Skv, D]
+            attention_mask (Optional[torch.Tensor], optional): Attention mask tensor. Shape should be [B, 1, Sq, Skv]. Defaults to None.
+            attention_mask_type (AttnMaskType, optional): Attention mask type. Defaults to AttnMaskType.CUSTOM.
+            cu_seqlens_q (Optional[torch.Tensor], optional): The cumulative sequence lengths
+                of the sequences in the batch, used to index into q.
+                Shape should be [B+1]. Defaults to None.
+            cu_seqlens_kv (Optional[torch.Tensor], optional): The cumulative sequence lengths
+                of the sequences in the batch, used to index into kv.
+                Shape should be [B+1]. Defaults to None.
+            max_seqlen_q (Optional[int], optional): Maximum query sequence length in the batch. Defaults to None.
+            max_seqlen_kv (Optional[int], optional): Maximum key/value sequence length in the batch. Defaults to None.
+            indices (Optional[torch.Tensor], optional): The indices of non-masked tokens from the flattened input sequence.
+                Shape should be [NUM_TOKENS]. Defaults to None.
+            dropout_p (float, optional): Dropout probability. Defaults to 0.0.
+            scale (Optional[float], optional): Scaling factor applied prior to softmax. Defaults to None.
+
+        Returns:
+            torch.Tensor: Output tensor. Shape should be [B, N, Sq, D]
+        """
+        # known issue: sdpa does not support attention mask which contains whole row of masked tokens, which leads to nan
+        # this case is usaul when padding mask is used and self attention is performed
+        # thus, we don't use sdpa when padding mask is used
+        # sanity check
+        if attention_mask is not None:
+            assert torch.is_floating_point(attention_mask), "attention_mask should be a floating point tensor."
+            if attention_mask_type in (AttnMaskType.CUSTOM, AttnMaskType.CAUSAL):
+                assert (
+                    cu_seqlens_q is None
+                    and cu_seqlens_kv is None
+                    and max_seqlen_q is None
+                    and max_seqlen_kv is None
+                    and q_indices is None
+                    and kv_indices is None
+                )
+                if attention_mask_type == AttnMaskType.CUSTOM:
+                    assert not torch.all(attention_mask != 0, dim=-1).any()
+            elif attention_mask_type in (
+                AttnMaskType.PADDED,
+                AttnMaskType.PADDED_CAUSAL,
+            ):
+                assert (
+                    cu_seqlens_q is not None
+                    and cu_seqlens_kv is not None
+                    and max_seqlen_q is not None
+                    and max_seqlen_kv is not None
+                    and q_indices is not None
+                    and kv_indices is not None
+                )
+        else:
+            # if attention_mask is None, attention_mask_type should be the default value
+            assert attention_mask_type == AttnMaskType.CUSTOM
+        # kernel dispatch
+        mask_type = attention_mask_type if attention_mask is not None else None
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        is_causal = attention_mask is not None and attention_mask_type in (
+            AttnMaskType.CAUSAL,
+            AttnMaskType.PADDED_CAUSAL,
+        )
+        return attn_func(
+            q,
+            k,
+            v,
+            dropout_p=dropout_p,
+            scale=scale,
+            attention_mask=attention_mask,
+            is_causal=is_causal,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
+            q_indices=q_indices,
+            kv_indices=kv_indices,
+        )
diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py
index d5c10541a..bd84c87c6 100644
--- a/colossalai/shardformer/modeling/blip2.py
+++ b/colossalai/shardformer/modeling/blip2.py
@@ -3,6 +3,8 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 
+from colossalai.shardformer.layer import ColoAttention
+
 
 def forward_fn():
     def forward(
@@ -62,8 +64,6 @@ def forward_fn():
 def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
 
-    from colossalai.nn.layer.colo_attention import ColoAttention
-
     def forward(
         self: Blip2Attention,
         hidden_states: torch.Tensor,
@@ -71,16 +71,25 @@ def get_blip2_flash_attention_forward():
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-
+        assert head_mask is None, "head_mask is not supported in FlashAttention"
         bsz, tgt_len, embed_dim = hidden_states.size()
         mixed_qkv = self.qkv(hidden_states)
-        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, -1).permute(2, 0, 1, 3, 4)
-        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
-
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout.p, scale=self.scale
+        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        query_states, key_states, value_states = (
+            mixed_qkv[0],
+            mixed_qkv[1],
+            mixed_qkv[2],
         )
-        context_layer = attention(query_states, key_states, value_states)
+
+        dropout_p = self.dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(
+            query_states,
+            key_states,
+            value_states,
+            dropout_p=dropout_p,
+            scale=self.scale,
+        )
+        context_layer = context_layer.permute(0, 2, 1, 3).reshape(bsz, tgt_len, self.embed_dim)
 
         output = self.projection(context_layer)
         outputs = (output, None)
@@ -93,7 +102,11 @@ def get_blip2_flash_attention_forward():
 def get_jit_fused_blip2_QFormer_self_output_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerSelfOutput
 
-    def forward(self: Blip2QFormerSelfOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self: Blip2QFormerSelfOutput,
+        hidden_states: torch.Tensor,
+        input_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
         hidden_states = self.LayerNorm(hidden_states)
@@ -105,7 +118,11 @@ def get_jit_fused_blip2_QFormer_self_output_forward():
 def get_jit_fused_blip2_QFormer_output_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerOutput
 
-    def forward(self: Blip2QFormerOutput, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self: Blip2QFormerOutput,
+        hidden_states: torch.Tensor,
+        input_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout_add(hidden_states, input_tensor, self.dropout.p, self.dropout.training)
         hidden_states = self.LayerNorm(hidden_states)
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index d13bd3492..a3e000e6e 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -1,4 +1,5 @@
 """ PyTorch ChatGLM model. """
+
 from typing import List, Optional, Tuple
 
 import torch
@@ -9,63 +10,49 @@ from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 
 def get_flash_core_attention_forward():
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
     from .chatglm2_6b.modeling_chatglm import CoreAttention
 
     def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
-        pytorch_major_version = int(torch.__version__.split(".")[0])
-        if pytorch_major_version >= 2:
-            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
-            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(
-                    query_layer, key_layer, value_layer, is_causal=True
-                )
-            else:
-                if attention_mask is not None:
-                    attention_mask = ~attention_mask
-                context_layer = torch.nn.functional.scaled_dot_product_attention(
-                    query_layer, key_layer, value_layer, attention_mask
-                )
-            context_layer = context_layer.permute(2, 0, 1, 3)
-            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-            context_layer = context_layer.reshape(*new_context_layer_shape)
+        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+            attention_mask_type = AttnMaskType.CAUSAL
+            attn_bias = torch.zeros(
+                query_layer.shape[0],
+                1,
+                query_layer.shape[2],
+                key_layer.shape[2],
+                dtype=query_layer.dtype,
+                device=query_layer.device,
+            )
+            temp_mask = (
+                torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+                .tril(diagonal=0)
+                .expand(query_layer.shape[0], 1, -1, -1)
+            )
+            attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
         else:
-            # Raw attention scores
-            query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
-            key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
-            value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
-
-            scale = 1.0 / self.norm_factor
-            if self.coeff is not None:
-                scale = scale * self.coeff
-
-            flash_attention_mask = None
-            attn_mask_type = None
-            if attention_mask is None:
-                attn_mask_type = AttnMaskType.causal
-            else:
-                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-                if not torch.all(flash_attention_mask):
-                    attn_mask_type = AttnMaskType.paddedcausal
-
-            attention = ColoAttention(
-                embed_dim=self.hidden_size_per_partition,
-                num_heads=self.num_attention_heads_per_partition,
-                dropout=self.attention_dropout.p,
-                scale=scale,
-            )
-            context_layer = attention(
-                query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
-            )
-
-            context_layer = context_layer.permute(1, 0, -1).contiguous()
-
+            attention_mask_type = AttnMaskType.CUSTOM
+            if attention_mask is not None:
+                attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
+                attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
+        dropout_p = self.attention_dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask=attn_bias,
+            attention_mask_type=attention_mask_type,
+            dropout_p=dropout_p,
+        )
+        context_layer = context_layer.permute(2, 0, 1, 3)
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(*new_context_layer_shape)
         return context_layer
 
     return forward
@@ -169,11 +156,17 @@ class ChatGLMPipelineForwards:
         if self.pre_seq_len is not None:
             if past_key_values is None:
                 past_key_values = self.get_prompt(
-                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+                    batch_size=batch_size,
+                    device=input_ids.device,
+                    dtype=inputs_embeds.dtype,
                 )
             if attention_mask is not None:
                 attention_mask = torch.cat(
-                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+                    [
+                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                        attention_mask,
+                    ],
+                    dim=-1,
                 )
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
@@ -200,7 +193,9 @@ class ChatGLMPipelineForwards:
 
         if shard_config.enable_sequence_parallelism:
             hidden_states = split_forward_gather_backward(
-                hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
             )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -208,7 +203,12 @@ class ChatGLMPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.encoder.gradient_checkpointing and self.encoder.training:
                 layer_ret = torch.utils.checkpoint.checkpoint(
-                    layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[idx],
+                    use_cache,
                 )
             else:
                 layer_ret = layer(
@@ -224,7 +224,9 @@ class ChatGLMPipelineForwards:
 
         if shard_config.enable_sequence_parallelism:
             hidden_states = gather_forward_split_backward(
-                hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
             )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -234,7 +236,14 @@ class ChatGLMPipelineForwards:
                 hidden_states = self.encoder.final_layernorm(hidden_states)
             if not return_dict:
                 return tuple(
-                    v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+                    v
+                    for v in [
+                        hidden_states,
+                        presents,
+                        all_hidden_states,
+                        all_self_attentions,
+                    ]
+                    if v is not None
                 )
             return BaseModelOutputWithPast(
                 last_hidden_state=hidden_states,
@@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # Run encoder.
         # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
         inputs_embeds = split_forward_gather_backward(
-            inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group
+            inputs_embeds,
+            dim=0,
+            process_group=shard_config.tensor_parallel_process_group,
         )
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds,
@@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
 
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=0,
+            process_group=shard_config.tensor_parallel_process_group,
         )
 
         if not return_dict:
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 407338b16..72f923bf0 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -21,12 +21,82 @@ from transformers.models.gpt2.modeling_gpt2 import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
 from ..layer._operation import gather_forward_split_backward
 
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+    self: GPT2Model,
+    shard_config: ShardConfig,
+    hidden_states: torch.Tensor,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
+    attention_mask: Optional[torch.FloatTensor],
+    encoder_hidden_states: Optional[torch.Tensor],
+    encoder_attention_mask: Optional[torch.FloatTensor],
+) -> Tuple[Optional[Union[torch.Tensor, dict]], Optional[Union[torch.Tensor, dict]]]:
+    batch_size, seq_len = hidden_states.shape[:2]
+    # If a 2D or 3D attention mask is provided for the cross-attention
+    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+    if self.config.add_cross_attention and encoder_hidden_states is not None:
+        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+        if shard_config.enable_flash_attention:
+            encoder_attention_mask = ColoAttention.prepare_attn_kwargs(
+                (encoder_batch_size, 1, seq_len, encoder_sequence_length),
+                dtype=hidden_states.dtype,
+                dtype2=encoder_hidden_states.dtype,
+                q_padding_mask=attention_mask,
+                kv_padding_mask=encoder_attention_mask,
+            )
+        else:
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=encoder_hidden_states.device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+    else:
+        if shard_config.enable_flash_attention:
+            encoder_attention_mask = {"attention_mask": None}
+        else:
+            encoder_attention_mask = None
+    # GPT2Attention mask.
+    past_key_values_length = 0
+    if past_key_values is not None and past_key_values[0] is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+    if shard_config.enable_flash_attention:
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = ColoAttention.prepare_attn_kwargs(
+            (batch_size, 1, seq_len, seq_len + past_key_values_length),
+            hidden_states.dtype,
+            hidden_states.device,
+            attention_mask,
+            is_causal=True,
+        )
+    elif attention_mask is not None:
+        if batch_size <= 0:
+            raise ValueError("batch_size has to be defined and > 0")
+        attention_mask = attention_mask.view(batch_size, -1)
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        attention_mask = attention_mask[:, None, None, :]
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+    return attention_mask, encoder_attention_mask
+
 
 class GPT2PipelineForwards:
     """
@@ -83,10 +153,10 @@ class GPT2PipelineForwards:
             elif input_ids is not None:
                 input_shape = input_ids.size()
                 input_ids = input_ids.view(-1, input_shape[-1])
-                batch_size = input_ids.shape[0]
+                input_ids.shape[0]
             elif inputs_embeds is not None:
                 input_shape = inputs_embeds.size()[:-1]
-                batch_size = inputs_embeds.shape[0]
+                inputs_embeds.shape[0]
             else:
                 raise ValueError("You have to specify either input_ids or inputs_embeds")
 
@@ -99,38 +169,7 @@ class GPT2PipelineForwards:
             input_shape = hidden_states.size()[:-1]
             device = hidden_states.device
             hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
-            batch_size = hidden_states.shape[0]
-
-        # GPT2Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
-        # If a 2D or 3D attention mask is provided for the cross-attention
-        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.config.add_cross_attention and encoder_hidden_states is not None:
-            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-            if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-        else:
-            encoder_attention_mask = None
+            hidden_states.shape[0]
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -156,6 +195,16 @@ class GPT2PipelineForwards:
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        attention_mask, encoder_attention_mask = _get_attention_mask(
+            self,
+            shard_config,
+            hidden_states,
+            past_key_values,
+            attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -171,7 +220,9 @@ class GPT2PipelineForwards:
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config.enable_sequence_parallelism:
             hidden_states = split_forward_gather_backward(
-                hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
             )
 
         # Going through held blocks.
@@ -180,7 +231,7 @@ class GPT2PipelineForwards:
             block = self.h[i]
             torch.cuda.set_device(hidden_states.device)
             # Ensure that attention_mask is always on the same device as hidden_states
-            if attention_mask is not None:
+            if torch.is_tensor(attention_mask):
                 attention_mask = attention_mask.to(hidden_states.device)
             if isinstance(head_mask, torch.Tensor):
                 head_mask = head_mask.to(hidden_states.device)
@@ -229,7 +280,9 @@ class GPT2PipelineForwards:
         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         if shard_config.enable_sequence_parallelism:
             hidden_states = gather_forward_split_backward(
-                hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
             )
 
         if stage_manager.is_last_stage():
@@ -245,7 +298,13 @@ class GPT2PipelineForwards:
             if not return_dict:
                 return tuple(
                     v
-                    for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                    for v in [
+                        hidden_states,
+                        presents,
+                        all_hidden_states,
+                        all_self_attentions,
+                        all_cross_attentions,
+                    ]
                     if v is not None
                 )
 
@@ -333,7 +392,9 @@ class GPT2PipelineForwards:
             shift_labels = shift_labels.view(-1)
             if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
                 loss = cross_entropy_1d(
-                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
             else:
                 loss = loss_fct(shift_logits, shift_labels)
@@ -733,27 +794,18 @@ class GPT2PipelineForwards:
 def get_gpt2_flash_attention_forward():
     from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
-    def split_heads(tensor, num_heads, attn_head_size):
-        """
-        Splits hidden_size dim into attn_head_size and num_heads
-        """
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(new_shape)
-        return tensor
-
     def forward(
         self: GPT2Attention,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         layer_past: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[dict] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[dict] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        assert head_mask is None, "FlashAttention does not support head_mask"
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
@@ -766,10 +818,9 @@ def get_gpt2_flash_attention_forward():
             attention_mask = encoder_attention_mask
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        query = split_heads(query, self.num_heads, self.head_dim)
-        key = split_heads(key, self.num_heads, self.head_dim)
-        value = split_heads(value, self.num_heads, self.head_dim)
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
 
         if layer_past is not None:
             past_key, past_value = layer_past
@@ -781,29 +832,14 @@ def get_gpt2_flash_attention_forward():
         else:
             present = None
 
-        if not self.is_cross_attention:
-            attn_mask_type = AttnMaskType.causal
-            flash_attention_mask = None
-        if attention_mask != None:
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                if attn_mask_type == AttnMaskType.causal:
-                    attn_mask_type == AttnMaskType.paddedcausal
-                else:
-                    attn_mask_type = AttnMaskType.padding
-
-        scale = value.size(-1) ** -0.5
+        scale = 1.0
+        if self.scale_attn_weights:
+            scale /= value.size(-1) ** 0.5
         if self.scale_attn_by_inverse_layer_idx:
-            scale = scale * (1 / float(self.layer_idx + 1))
-
-        # use coloattention
-        if not hasattr(self, "attention"):
-            self.attention = ColoAttention(
-                embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.attn_dropout.p, scale=scale
-            )
-
-        attn_output = self.attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
-
+            scale /= float(self.layer_idx + 1)
+        dropout_p = self.attn_dropout.p if self.training else 0.0
+        attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p, scale=scale)
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
         outputs = (attn_output, present, None)
@@ -813,6 +849,195 @@ def get_gpt2_flash_attention_forward():
     return forward
 
 
+def get_gpt_model_forward_for_flash_attn(shard_config: ShardConfig):
+    def forward(
+        self: GPT2Model,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(
+                past_length,
+                input_shape[-1] + past_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+        attention_mask, encoder_attention_mask = _get_attention_mask(
+            self,
+            shard_config,
+            hidden_states,
+            past_key_values,
+            attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if torch.is_tensor(attention_mask):
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+    return forward
+
+
 def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
     def forward(
         self,
@@ -842,10 +1067,10 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         elif input_ids is not None:
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
-            batch_size = input_ids.shape[0]
+            input_ids.shape[0]
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
-            batch_size = inputs_embeds.shape[0]
+            inputs_embeds.shape[0]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
@@ -862,40 +1087,14 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(
+                past_length,
+                input_shape[-1] + past_length,
+                dtype=torch.long,
+                device=device,
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 
-        # GPT2Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
-        # If a 2D or 3D attention mask is provided for the cross-attention
-        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if self.config.add_cross_attention and encoder_hidden_states is not None:
-            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-            if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-        else:
-            encoder_attention_mask = None
-
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
@@ -914,6 +1113,15 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         hidden_states = self.drop(hidden_states)
 
         output_shape = input_shape + (hidden_states.size(-1),)
+        attention_mask, encoder_attention_mask = _get_attention_mask(
+            self,
+            shard_config,
+            hidden_states,
+            past_key_values,
+            attention_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -931,7 +1139,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
         )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -942,7 +1152,7 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
                 if layer_past is not None:
                     layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                 # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
+                if torch.is_tensor(attention_mask):
                     attention_mask = attention_mask.to(hidden_states.device)
                 if isinstance(head_mask, torch.Tensor):
                     head_mask = head_mask.to(hidden_states.device)
@@ -996,7 +1206,9 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
 
         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
         )
 
         hidden_states = self.ln_f(hidden_states)
@@ -1008,7 +1220,13 @@ def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
         if not return_dict:
             return tuple(
                 v
-                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
                 if v is not None
             )
 
diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index 1990d7df3..5c254d1e7 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -19,9 +19,54 @@ from transformers.models.gptj.modeling_gptj import (
 from transformers.utils import is_torch_fx_proxy, logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
 
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+    self: GPTJModel,
+    shard_config: ShardConfig,
+    hidden_states: torch.Tensor,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
+    attention_mask: Optional[torch.FloatTensor],
+) -> Optional[Union[torch.Tensor, dict]]:
+    batch_size, seq_len = hidden_states.shape[:2]
+    past_key_values_length = 0
+    if past_key_values is not None and past_key_values[0] is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+    if shard_config.enable_flash_attention:
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = ColoAttention.prepare_attn_kwargs(
+            (batch_size, 1, seq_len, seq_len + past_key_values_length),
+            hidden_states.dtype,
+            hidden_states.device,
+            attention_mask,
+            is_causal=True,
+        )
+    elif attention_mask is not None:
+        if batch_size <= 0:
+            raise ValueError("batch_size has to be defined and > 0")
+        attention_mask = attention_mask.view(batch_size, -1)
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        attention_mask = attention_mask[:, None, None, :]
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+    return attention_mask
+
 
 class GPTJPipelineForwards:
     """
@@ -96,26 +141,6 @@ class GPTJPipelineForwards:
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
 
-        # Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x num_attention_heads x N x N
@@ -139,6 +164,8 @@ class GPTJPipelineForwards:
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
+        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
@@ -154,7 +181,9 @@ class GPTJPipelineForwards:
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config.enable_sequence_parallelism:
             hidden_states = split_forward_gather_backward(
-                hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
             )
 
         # Going through held blocks.
@@ -209,7 +238,9 @@ class GPTJPipelineForwards:
         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         if shard_config.enable_sequence_parallelism:
             hidden_states = gather_forward_split_backward(
-                hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
             )
 
         if stage_manager.is_last_stage():
@@ -223,7 +254,14 @@ class GPTJPipelineForwards:
         if stage_manager.is_last_stage():
             if not return_dict:
                 return tuple(
-                    v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+                    v
+                    for v in [
+                        hidden_states,
+                        presents,
+                        all_hidden_states,
+                        all_self_attentions,
+                    ]
+                    if v is not None
                 )
 
             return BaseModelOutputWithPast(
@@ -530,24 +568,11 @@ class GPTJPipelineForwards:
 def get_gptj_flash_attention_forward():
     from transformers.models.gptj.modeling_gptj import GPTJAttention
 
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
-    def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
-        """
-        Splits hidden dim into attn_head_size and num_attention_heads
-        """
-        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
-        tensor = tensor.view(new_shape)
-        if rotary or len(tensor.shape) in [4, 5]:
-            return tensor
-        else:
-            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
-
     def forward(
         self: GPTJAttention,
         hidden_states: torch.FloatTensor,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[dict] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
@@ -556,13 +581,14 @@ def get_gptj_flash_attention_forward():
         Tuple[torch.Tensor, Tuple[torch.Tensor]],
         Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
     ]:
+        assert head_mask is None, "head_mask is not supported for FlashAttention"
         query = self.q_proj(hidden_states)
         key = self.k_proj(hidden_states)
         value = self.v_proj(hidden_states)
 
-        query = split_heads(query, self.num_attention_heads, self.head_dim, True)
-        key = split_heads(key, self.num_attention_heads, self.head_dim, True)
-        value = split_heads(value, self.num_attention_heads, self.head_dim, False)
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
 
         if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
             # The logic to conditionally copy to GPU could not be traced, so we do this
@@ -591,41 +617,23 @@ def get_gptj_flash_attention_forward():
             key = apply_rotary_pos_emb(key, sin, cos)
             query = apply_rotary_pos_emb(query, sin, cos)
 
-        # key = key.permute(0, 2, 1, 3)
-        # query = query.permute(0, 2, 1, 3)
-        key = key.to(dtype=value.dtype)  # fp16 compatibility
-        query = query.to(dtype=value.dtype)
+        key = key.permute(0, 2, 1, 3)
+        query = query.permute(0, 2, 1, 3)
 
         if layer_past is not None:
             past_key = layer_past[0]
             past_value = layer_past[1]
-            key = torch.cat((past_key, key), dim=1)
-            value = torch.cat((past_value, value), dim=1)
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
             present = (key, value)
         else:
             present = None
 
-        # use AttnMaskType and ColoAttention
-        attn_mask_type = AttnMaskType.causal
-        flash_attention_mask = None
-        if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-
-        # use coloattention
-        scale = value.size(-1) ** -0.5
-
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_attention_heads, dropout=self.attn_dropout.p, scale=scale
-        )
-
-        attn_output = attention(query, key, value, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type)
-
+        dropout_p = self.attn_dropout.p if self.training else 0.0
+        attn_output = ColoAttention.attention(query, key, value, **attention_mask, dropout_p=dropout_p)
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
         attn_output = self.out_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
         outputs = (attn_output, present, None)
@@ -635,6 +643,180 @@ def get_gptj_flash_attention_forward():
     return forward
 
 
+def gptj_model_forward_for_flash_attention(shard_config: ShardConfig):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+            input_ids.shape[0]
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1]).long()
+
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(
+                past_length,
+                input_shape[-1] + past_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x num_attention_heads x N x N
+        # head_mask has shape n_layer x batch x num_attention_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+
+        hidden_states = inputs_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
+
+        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure layer_past is on same device as hidden_states (might not be correct)
+                if layer_past is not None:
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if isinstance(head_mask, torch.Tensor):
+                    head_mask = head_mask.to(hidden_states.device)
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    position_ids,
+                    head_mask[i],
+                )
+            else:
+                outputs = block(
+                    hidden_states=hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.ln_f(hidden_states)
+
+        hidden_states = hidden_states.view(output_shape)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    return forward
+
+
 def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
     def forward(
         self,
@@ -662,10 +844,10 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
         elif input_ids is not None:
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
-            batch_size = input_ids.shape[0]
+            input_ids.shape[0]
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
-            batch_size = inputs_embeds.shape[0]
+            inputs_embeds.shape[0]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
@@ -684,29 +866,14 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             past_length = past_key_values[0][0].size(-2)
 
         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(
+                past_length,
+                input_shape[-1] + past_length,
+                dtype=torch.long,
+                device=device,
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 
-        # Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x num_attention_heads x N x N
@@ -725,6 +892,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
         hidden_states = self.drop(hidden_states)
 
         output_shape = input_shape + (hidden_states.size(-1),)
+        attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -740,7 +908,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
         )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -801,7 +971,9 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
 
         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
         )
 
         hidden_states = self.ln_f(hidden_states)
@@ -812,7 +984,16 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
+            )
 
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d5e02b64c..1f17144f5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -15,7 +15,9 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
-from ..layer import cross_entropy_1d
+
+from ..layer import ColoAttention, cross_entropy_1d
+
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -105,18 +107,25 @@ class LlamaPipelineForwards:
 
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
-        if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        if shard_config.enable_flash_attention:
+            # in this case, attention_mask is a dict rather than a tensor
+            mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+            attention_mask = ColoAttention.prepare_attn_kwargs(
+                mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
             )
         else:
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-            )
+            if attention_mask is None:
+                attention_mask = torch.ones(
+                    (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+                )
+            if LATEST_VERSION:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )
+            else:
+                attention_mask = self._prepare_decoder_attention_mask(
+                    attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+                )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
@@ -262,6 +271,7 @@ class LlamaPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         past_key_values = None
 
@@ -352,6 +362,7 @@ class LlamaPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
 
         if input_ids is not None:
@@ -420,8 +431,6 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
     llama_version = 2
     try:
         from transformers.models.llama.modeling_llama import repeat_kv
@@ -432,7 +441,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
     def forward(
         self: LlamaAttention,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[dict] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
@@ -466,31 +475,10 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
-        query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        key_states = key_states.transpose(1, 2).contiguous().view(*me_input_shape)
-        value_states = value_states.transpose(1, 2).contiguous().view(*me_input_shape)
-
-        flash_attention_mask = None
-        attn_mask_type = AttnMaskType.causal
-        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            attn_mask_type = AttnMaskType.paddedcausal
-
-        if not hasattr(self, "attention"):
-            self.attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=flash_attention_mask,
-            attn_mask_type=attn_mask_type,
-            origin_attn_mask=attention_mask,
-        )
+        assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+        attn_output = ColoAttention.attention(query_states, key_states, value_states, **attention_mask)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
 
@@ -499,6 +487,137 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
     return forward
 
 
+def get_llama_model_forward_for_flash_attn(shard_config: ShardConfig):
+    logger = logging.get_logger(__name__)
+    assert shard_config.enable_flash_attention, "Flash Attention is not enabled."
+
+    def forward(
+        self: LlamaModel,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        hidden_states = inputs_embeds
+
+        # in this case, attention_mask is a dict rather than a tensor
+        mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
+        attention_mask = ColoAttention.prepare_attn_kwargs(
+            mask_shape, hidden_states.dtype, hidden_states.device, q_padding_mask=attention_mask, is_causal=True
+        )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    return forward
+
+
 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import LlamaForCausalLM
 
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index d0e267eac..a26526430 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -18,6 +18,37 @@ from transformers.models.opt.modeling_opt import (
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
+from colossalai.shardformer.shard import ShardConfig
+
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+    self: OPTModel,
+    shard_config: ShardConfig,
+    hidden_states: torch.Tensor,
+    past_key_values_length: int,
+    attention_mask: Optional[torch.FloatTensor],
+):
+    batch_size, seq_length = hidden_states.shape[:2]
+    mask_seq_length = past_key_values_length + seq_length
+    if shard_config.enable_flash_attention:
+        attention_mask = ColoAttention.prepare_attn_kwargs(
+            (batch_size, 1, seq_length, mask_seq_length),
+            hidden_states.dtype,
+            hidden_states.device,
+            attention_mask,
+            is_causal=True,
+        )
+    else:
+        attention_mask = self.decoder._prepare_decoder_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            hidden_states,
+            past_key_values_length,
+        )
+    return attention_mask
 
 
 class OPTPipelineForwards:
@@ -26,46 +57,6 @@ class OPTPipelineForwards:
     under pipeline setting.
     """
 
-    @staticmethod
-    def _prepare_decoder_attention_mask(attention_mask, input_shape, _dtype, device, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        from transformers.models.opt.modeling_opt import _make_causal_mask
-
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                _dtype,
-                device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = OPTPipelineForwards._expand_mask(attention_mask, _dtype, tgt_len=input_shape[-1]).to(
-                device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
-    @staticmethod
-    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-        """
-        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-        """
-        bsz, src_len = mask.size()
-        tgt_len = tgt_len if tgt_len is not None else src_len
-
-        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
-        inverted_mask = 1.0 - expanded_mask
-
-        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
     @staticmethod
     def opt_model_forward(
         self: OPTModel,
@@ -81,6 +72,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: Optional[ShardConfig] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         """
         This forward method is modified based on transformers.models.opt.modeling_opt.OPTModel.forward
@@ -119,7 +111,7 @@ class OPTPipelineForwards:
             if decoder.project_in is not None:
                 inputs_embeds = decoder.project_in(inputs_embeds)
             device = input_ids.device if input_ids is not None else inputs_embeds.device
-            _dtype = inputs_embeds.dtype
+            inputs_embeds.dtype
 
         else:
             if hidden_states is None:
@@ -127,7 +119,7 @@ class OPTPipelineForwards:
             input_shape = hidden_states.size()[:-1]
             batch_size, seq_length = input_shape[0], input_shape[1]
             device = hidden_states.device
-            _dtype = hidden_states.dtype
+            hidden_states.dtype
 
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
         # required mask seq length can be calculated via length of past
@@ -141,13 +133,24 @@ class OPTPipelineForwards:
                 f"{mask_seq_length} (sum of the lengths of current and past inputs)"
             )
 
-        causal_attention_mask = OPTPipelineForwards._prepare_decoder_attention_mask(
-            attention_mask, input_shape, _dtype, device, past_key_values_length
-        )
-
         if stage_manager.is_first_stage():
+            causal_attention_mask = _get_attention_mask(
+                self,
+                shard_config,
+                inputs_embeds,
+                past_key_values_length,
+                attention_mask,
+            )
             pos_embeds = decoder.embed_positions(attention_mask, past_key_values_length)
             hidden_states = inputs_embeds + pos_embeds
+        else:
+            causal_attention_mask = _get_attention_mask(
+                self,
+                shard_config,
+                hidden_states,
+                past_key_values_length,
+                attention_mask,
+            )
 
         if decoder.gradient_checkpointing and decoder.training:
             if use_cache:
@@ -249,7 +252,16 @@ class OPTPipelineForwards:
 
         if stage_manager.is_last_stage():
             if not return_dict:
-                return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+                return tuple(
+                    v
+                    for v in [
+                        hidden_states,
+                        next_cache,
+                        all_hidden_states,
+                        all_self_attns,
+                    ]
+                    if v is not None
+                )
 
             return BaseModelOutputWithPast(
                 last_hidden_state=hidden_states,
@@ -276,6 +288,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: Optional[ShardConfig] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForCausalLM.forward.
@@ -303,6 +316,7 @@ class OPTPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         if stage_manager.is_last_stage():
             logits = self.lm_head(outputs[0]).contiguous()
@@ -347,6 +361,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: Optional[ShardConfig] = None,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForSequenceClassification.forward.
@@ -371,6 +386,7 @@ class OPTPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
 
         if stage_manager.is_last_stage():
@@ -448,6 +464,7 @@ class OPTPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
+        shard_config: Optional[ShardConfig] = None,
     ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         This function is modified on the basis of transformers.models.opt.modeling_opt.OPTForQuestionAnswering.forward.
@@ -469,6 +486,7 @@ class OPTPipelineForwards:
             stage_manager=stage_manager,
             hidden_states=hidden_states,
             stage_index=stage_index,
+            shard_config=shard_config,
         )
         if stage_manager.is_last_stage():
             hidden_states = transformer_outputs[0]
@@ -511,49 +529,47 @@ class OPTPipelineForwards:
             return {"hidden_states": hidden_states}
 
 
-def get_opt_flash_attention_forward():
+def get_opt_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.opt.modeling_opt import OPTAttention
 
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
     def forward(
         self: OPTAttention,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-
+        assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
+
         bsz, tgt_len, _ = hidden_states.size()
 
-        attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
         # get query proj
-        query_states = self.q_proj(hidden_states).view(*attention_input_shape)
+        query_states = self.q_proj(hidden_states)
         # get key, value proj
         if is_cross_attention and past_key_value is not None:
-            # reuse k, v, cross_attentions
-            key_states = past_key_value[0].transpose(1, 2).contiguous().view(*attention_input_shape)
-            value_states = past_key_value[1].transpose(1, 2).contiguous().view(*attention_input_shape)
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
         elif is_cross_attention:
             # cross_attentions
-            key_states = self.k_proj(key_value_states).view(*attention_input_shape)
-            value_states = self.v_proj(key_value_states).view(*attention_input_shape)
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
         elif past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = self.k_proj(hidden_states).view(*attention_input_shape)
-            value_states = self.v_proj(hidden_states).view(*attention_input_shape)
-            key_states = torch.cat([past_key_value[0], key_states], dim=1)
-            value_states = torch.cat([past_key_value[1], value_states], dim=1)
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
         else:
             # self_attention
-            key_states = self.k_proj(hidden_states).view(*attention_input_shape)
-            value_states = self.v_proj(hidden_states).view(*attention_input_shape)
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -565,38 +581,181 @@ def get_opt_flash_attention_forward():
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)
 
-        src_len = key_states.size(1)
-        if layer_head_mask != None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                    f" {layer_head_mask.size()}"
-                )
+        query_states = self._shape(query_states, tgt_len, bsz)
 
-        flash_attention_mask = None
-        attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
-            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
-                )
-            flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+        dropout_p = self.dropout if self.training else 0.0
+        attn_output = ColoAttention.attention(
+            query_states,
+            key_states,
+            value_states,
+            **attention_mask,
+            dropout_p=dropout_p,
+            scale=self.scaling,
+        )
 
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
-        )
-        attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
-        )
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned aross GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
 
         attn_output = self.out_proj(attn_output)
+
         return attn_output, None, past_key_value
 
     return forward
 
 
+def get_opt_decoder_forward_for_flash_attention(shard_config: ShardConfig):
+    from transformers.models.opt.modeling_opt import OPTDecoder
+
+    def forward(
+        self: OPTDecoder,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        batch_size, seq_length = input_shape
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        # required mask seq length can be calculated via length of past
+        mask_seq_length = past_key_values_length + seq_length
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != mask_seq_length:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
+        causal_attention_mask = _get_attention_mask(
+            self, shard_config, inputs_embeds, past_key_values_length, attention_mask
+        )
+        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
+
+        if self.project_in is not None:
+            inputs_embeds = self.project_in(inputs_embeds)
+
+        hidden_states = inputs_embeds + pos_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != (len(self.layers)):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {head_mask.size()[0]}."
+                    )
+
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    causal_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        if self.final_layer_norm is not None:
+            hidden_states = self.final_layer_norm(hidden_states)
+
+        if self.project_out is not None:
+            hidden_states = self.project_out(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    return forward
+
+
 def get_jit_fused_opt_decoder_layer_forward():
     from transformers.models.opt.modeling_opt import OPTDecoderLayer
 
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index ab141a74a..e9c256a13 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -1,4 +1,3 @@
-import math
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -6,6 +5,7 @@ from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
 
 
 def _encoder_forward(
@@ -98,7 +98,9 @@ def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index:
                 pixel_values = pixel_values.to(expected_dtype)
 
             embedding_output = self.embeddings(
-                pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+                pixel_values,
+                bool_masked_pos=bool_masked_pos,
+                interpolate_pos_encoding=interpolate_pos_encoding,
             )
             hidden_states = embedding_output
         else:
@@ -336,34 +338,27 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
 def get_vit_flash_self_attention_forward():
     from transformers.models.vit.modeling_vit import ViTSelfAttention
 
-    from colossalai.nn.layer.colo_attention import ColoAttention
-
-    def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
-        x = x.view(new_x_shape)
-        return x
-
     def forward(
         self: ViTSelfAttention,
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        assert head_mask is None, "head_mask is not supported for FlashAttention"
         mixed_query_layer = self.query(hidden_states)
 
-        key_layer = transpose_for_scores(self.key(hidden_states), self.num_attention_heads, self.attention_head_size)
-        value_layer = transpose_for_scores(
-            self.value(hidden_states), self.num_attention_heads, self.attention_head_size
-        )
-        query_layer = transpose_for_scores(mixed_query_layer, self.num_attention_heads, self.attention_head_size)
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
 
-        scale = 1.0 / math.sqrt(self.attention_head_size)
-        attention = ColoAttention(
-            embed_dim=self.all_head_size, num_heads=self.num_attention_heads, dropout=self.dropout.p, scale=scale
-        )
-        context_layer = attention(query_layer, key_layer, value_layer)
+        dropout_p = self.dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, dropout_p=dropout_p)
 
-        outputs = (context_layer,)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, None) if output_attentions else (context_layer,)
 
         return outputs
 
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index cb8b45ae7..7ccc79276 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -13,41 +13,74 @@ from transformers.modeling_outputs import (
     SequenceClassifierOutput,
 )
 from transformers.models.whisper.modeling_whisper import (
+    WhisperDecoder,
     WhisperEncoder,
     WhisperForAudioClassification,
     WhisperForConditionalGeneration,
     WhisperModel,
+    shift_tokens_right,
 )
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import ColoAttention
+from colossalai.shardformer.shard import ShardConfig
+
+logger = logging.get_logger(__name__)
+
+
+def _get_attention_mask(
+    self: WhisperDecoder,
+    shard_config: ShardConfig,
+    hidden_states: torch.Tensor,
+    past_key_values_length: int,
+    attention_mask: Optional[torch.FloatTensor],
+):
+    batch_size, seq_length = hidden_states.shape[:2]
+    mask_seq_length = past_key_values_length + seq_length
+    if shard_config.enable_flash_attention:
+        attention_mask = ColoAttention.prepare_attn_kwargs(
+            (batch_size, 1, seq_length, mask_seq_length),
+            hidden_states.dtype,
+            hidden_states.device,
+            attention_mask,
+            is_causal=True,
+        )
+    else:
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            hidden_states,
+            past_key_values_length,
+        )
+    return attention_mask
 
 
 def get_whisper_flash_attention_forward():
     from transformers.models.whisper.modeling_whisper import WhisperAttention
 
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
-    def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
-        return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
-
     def forward(
         self: WhisperAttention,
         hidden_states: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[dict] = None,
         layer_head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
-
+        assert layer_head_mask is None, "layer_head_mask is not supported for FlashAttention"
+        # for encoder, attention_mask is None
+        if attention_mask is None:
+            attention_mask = {}
         # if key_value_states are provided this layer is used as a cross-attention layer
         # for the decoder
         is_cross_attention = key_value_states is not None
 
         bsz, tgt_len, _ = hidden_states.size()
 
+        # get query proj
+        query_states = self.q_proj(hidden_states)
         # get key, value proj
         # `past_key_value[0].shape[2] == key_value_states.shape[1]`
         # is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -55,25 +88,25 @@ def get_whisper_flash_attention_forward():
         if (
             is_cross_attention
             and past_key_value is not None
-            and past_key_value[0].shape[1] == key_value_states.shape[1]
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
         ):
             # reuse k,v, cross_attentions
             key_states = past_key_value[0]
             value_states = past_key_value[1]
         elif is_cross_attention:
             # cross_attentions
-            key_states = shape(self.k_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
-            value_states = shape(self.v_proj(key_value_states), -1, bsz, self.num_heads, self.head_dim)
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
         elif past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
-            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
-            key_states = torch.cat([past_key_value[0], key_states], dim=1)
-            value_states = torch.cat([past_key_value[1], value_states], dim=1)
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
         else:
             # self_attention
-            key_states = shape(self.k_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
-            value_states = shape(self.v_proj(hidden_states), -1, bsz, self.num_heads, self.head_dim)
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
         if self.is_decoder:
             # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -85,38 +118,22 @@ def get_whisper_flash_attention_forward():
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)
 
-        # get query proj
-        query_states = shape(self.q_proj(hidden_states), tgt_len, bsz, self.num_heads, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz)
 
-        src_len = key_states.size(1)
-        if layer_head_mask is not None:
-            if layer_head_mask.size() != (self.num_heads,):
-                raise ValueError(
-                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
-                    f" {layer_head_mask.size()}"
-                )
-
-        attn_type = None
-        flash_attention_mask = None
-
-        if self.is_decoder:
-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, tgt_len, src_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
-                    )
-                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-                if not torch.all(flash_attention_mask):
-                    attn_type = AttnMaskType.paddedcausal
-                else:
-                    attn_type = AttnMaskType.causal
-
-        attention = ColoAttention(
-            embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
-        )
-        attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_type
+        dropout_p = self.dropout if self.training else 0.0
+        attn_output = ColoAttention.attention(
+            query_states,
+            key_states,
+            value_states,
+            **attention_mask,
+            dropout_p=dropout_p,
+            scale=self.scaling,
         )
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
 
         attn_output = self.out_proj(attn_output)
 
@@ -125,6 +142,158 @@ def get_whisper_flash_attention_forward():
     return forward
 
 
+def get_whisper_decoder_forward_for_flash_attention(shard_config: ShardConfig):
+    def forward(
+        self: WhisperDecoder,
+        input_ids=None,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        attention_mask = _get_attention_mask(self, shard_config, inputs_embeds, past_key_values_length, attention_mask)
+
+        # embed positions
+        if input_ids is not None:
+            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+        else:
+            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+
+        hidden_states = inputs_embeds + positions
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+                )
+                use_cache = False
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                assert attn_mask.size()[0] == (len(self.layers)), (
+                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, use_cache)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    None,  # encoder attention mask
+                    head_mask[idx] if head_mask is not None else None,
+                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
+                    None,  # past_key_value
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+    return forward
+
+
 def get_jit_fused_whisper_encoder_layer_forward():
     from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
 
@@ -292,6 +461,7 @@ class WhisperPipelineForwards:
         all_attentions=None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
+        shard_config: Optional[ShardConfig] = None,
     ):
         r"""
         Args:
@@ -403,7 +573,9 @@ class WhisperPipelineForwards:
             if not return_dict:
                 return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
             return BaseModelOutput(
-                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+                last_hidden_state=hidden_states,
+                hidden_states=encoder_states,
+                attentions=all_attentions,
             )
 
         else:
@@ -411,7 +583,7 @@ class WhisperPipelineForwards:
 
     @staticmethod
     def whisper_decoder_forward(
-        self,
+        self: WhisperDecoder,
         input_ids=None,
         attention_mask=None,
         encoder_hidden_states=None,
@@ -427,6 +599,7 @@ class WhisperPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
+        shard_config: Optional[ShardConfig] = None,
     ):
         r"""
         Args:
@@ -535,8 +708,12 @@ class WhisperPipelineForwards:
             else:
                 positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
 
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            attention_mask = _get_attention_mask(
+                self,
+                shard_config,
+                inputs_embeds,
+                past_key_values_length,
+                attention_mask,
             )
 
             hidden_states = inputs_embeds + positions
@@ -556,8 +733,12 @@ class WhisperPipelineForwards:
                 )
             input_shape = hidden_states.size()[:-1]
 
-            attention_mask = self._prepare_decoder_attention_mask(
-                attention_mask, input_shape, hidden_states, past_key_values_length
+            attention_mask = _get_attention_mask(
+                self,
+                shard_config,
+                hidden_states,
+                past_key_values_length,
+                attention_mask,
             )
 
         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -590,7 +771,7 @@ class WhisperPipelineForwards:
                     encoder_hidden_states,
                     None,  # encoder attention mask
                     head_mask[idx] if head_mask is not None else None,
-                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                     None,  # past_key_value
                 )
             else:
@@ -626,7 +807,13 @@ class WhisperPipelineForwards:
             if not return_dict:
                 return tuple(
                     v
-                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                    for v in [
+                        hidden_states,
+                        next_cache,
+                        all_hidden_states,
+                        all_self_attns,
+                        all_cross_attentions,
+                    ]
                     if v is not None
                 )
             return BaseModelOutputWithPastAndCrossAttentions(
@@ -666,6 +853,7 @@ class WhisperPipelineForwards:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
+        shard_config: Optional[ShardConfig] = None,
     ):
         r"""
         Returns:
@@ -735,7 +923,7 @@ class WhisperPipelineForwards:
             elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
                 encoder_outputs = BaseModelOutput(
                     last_hidden_state=encoder_outputs[0],
-                    hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                    hidden_states=(encoder_outputs[1] if len(encoder_outputs) > 1 else None),
                     attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
                 )
 
@@ -767,6 +955,7 @@ class WhisperPipelineForwards:
             hidden_states=hidden_states,
             stage_index=stage_index,
             decoder_starting_stage=decoder_starting_stage,
+            shard_config=shard_config,
         )
 
         # Directly return outputs of overloaded Whisper forward if not at last stage.
@@ -810,6 +999,7 @@ class WhisperPipelineForwards:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
+        shard_config: Optional[ShardConfig] = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -870,6 +1060,7 @@ class WhisperPipelineForwards:
             encoder_hidden_states=encoder_hidden_states,
             stage_index=stage_index,
             decoder_starting_stage=decoder_starting_stage,
+            shard_config=shard_config,
         )
         if not in_decoder:
             return outputs
@@ -920,6 +1111,7 @@ class WhisperPipelineForwards:
         all_attentions=None,
         stage_index: Optional[List[int]] = None,
         decoder_starting_stage: Optional[int] = None,
+        shard_config: Optional[ShardConfig] = None,
     ):
         r"""
         This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 6a50d65ba..fcf40fa39 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -8,6 +8,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.gpt2 import (
     GPT2PipelineForwards,
     get_gpt2_flash_attention_forward,
+    get_gpt_model_forward_for_flash_attn,
     get_lm_forward_with_dist_cross_entropy,
     gpt2_sequence_parallel_forward_fn,
 )
@@ -75,7 +76,11 @@ class GPT2Policy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attn.c_attn",
                         target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                        kwargs={"n_fused": 3, "seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "n_fused": 3,
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.c_proj",
@@ -87,7 +92,11 @@ class GPT2Policy(Policy):
                     SubModuleReplacementDescription(
                         suffix="mlp.c_fc",
                         target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                        kwargs={"n_fused": 1, "seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "n_fused": 1,
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.c_proj",
@@ -150,6 +159,10 @@ class GPT2Policy(Policy):
                 policy=policy,
                 target_key=GPT2Attention,
             )
+            if not self.shard_config.pipeline_stage_manager:
+                policy[GPT2Model].method_replacement = {
+                    "forward": get_gpt_model_forward_for_flash_attn(self.shard_config)
+                }
 
         if self.shard_config.enable_sequence_parallelism:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
@@ -223,14 +236,21 @@ class GPT2Policy(Policy):
                 num_stages=stage_manager.num_stages,
             )
             method_replacement = {
-                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+                "forward": partial(
+                    new_forward,
+                    stage_manager=stage_manager,
+                    shard_config=self.shard_config,
+                )
             }
         else:
             layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
-                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                    new_forward,
+                    stage_manager=stage_manager,
+                    stage_index=stage_index,
+                    shard_config=self.shard_config,
                 )
             }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -245,7 +265,9 @@ class GPT2ModelPolicy(GPT2Policy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=GPT2Model, new_forward=GPT2PipelineForwards.gpt2_model_forward, policy=policy
+                model_cls=GPT2Model,
+                new_forward=GPT2PipelineForwards.gpt2_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -299,7 +321,12 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         if stage_manager is not None:
             if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
                 first_stage, last_stage = 0, stage_manager.num_stages - 1
-                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+                return [
+                    {
+                        first_stage: module.transformer.wte.weight,
+                        last_stage: module.lm_head.weight,
+                    }
+                ]
         return []
 
 
@@ -315,7 +342,9 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
                 GPT2DoubleHeadsModel: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                            suffix="lm_head",
+                            target_module=col_nn.Linear1D_Col,
+                            kwargs={"gather_output": True},
                         )
                     ]
                 )
@@ -350,7 +379,12 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         if stage_manager is not None:
             if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
                 first_stage, last_stage = 0, stage_manager.num_stages - 1
-                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+                return [
+                    {
+                        first_stage: module.transformer.wte.weight,
+                        last_stage: module.lm_head.weight,
+                    }
+                ]
         return []
 
 
@@ -392,7 +426,10 @@ class GPT2ForTokenClassificationPolicy(GPT2Policy):
             addon_module = {
                 GPT2ForTokenClassification: ModulePolicyDescription(
                     sub_module_replacement=[
-                        SubModuleReplacementDescription(suffix="dropout", target_module=col_nn.DropoutForParallelInput)
+                        SubModuleReplacementDescription(
+                            suffix="dropout",
+                            target_module=col_nn.DropoutForParallelInput,
+                        )
                     ]
                 )
             }
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 9feb826c4..b001a2009 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -6,7 +6,11 @@ from torch import Tensor, nn
 
 import colossalai.shardformer.layer as col_nn
 
-from ..modeling.gptj import GPTJPipelineForwards, get_gptj_flash_attention_forward
+from ..modeling.gptj import (
+    GPTJPipelineForwards,
+    get_gptj_flash_attention_forward,
+    gptj_model_forward_for_flash_attention,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -71,17 +75,26 @@ class GPTJPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.out_proj",
@@ -143,6 +156,12 @@ class GPTJPolicy(Policy):
                 policy=policy,
                 target_key=GPTJAttention,
             )
+            if not self.shard_config.pipeline_stage_manager:
+                self.append_or_create_method_replacement(
+                    description={"forward": gptj_model_forward_for_flash_attention(self.shard_config)},
+                    policy=policy,
+                    target_key=GPTJModel,
+                )
 
         return policy
 
@@ -185,7 +204,10 @@ class GPTJPolicy(Policy):
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
         method_replacement = {
             "forward": partial(
-                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                shard_config=self.shard_config,
             )
         }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -203,7 +225,9 @@ class GPTJModelPolicy(GPTJPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=GPTJModel, new_forward=GPTJPipelineForwards.gptj_model_forward, policy=policy
+                model_cls=GPTJModel,
+                new_forward=GPTJPipelineForwards.gptj_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -230,7 +254,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
                 GPTJForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                            suffix="lm_head",
+                            target_module=col_nn.Linear1D_Col,
+                            kwargs={"gather_output": True},
                         )
                     ]
                 )
@@ -239,7 +265,9 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=GPTJForCausalLM, new_forward=GPTJPipelineForwards.gptj_causallm_model_forward, policy=policy
+                model_cls=GPTJForCausalLM,
+                new_forward=GPTJPipelineForwards.gptj_causallm_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -256,7 +284,12 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
         if stage_manager is not None:
             if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
                 first_stage, last_stage = 0, stage_manager.num_stages - 1
-                return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
+                return [
+                    {
+                        first_stage: module.transformer.wte.weight,
+                        last_stage: module.lm_head.weight,
+                    }
+                ]
         return []
 
 
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 4c454ac7f..37c2c261b 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -11,6 +11,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Ro
 from ..modeling.llama import (
     LlamaPipelineForwards,
     get_llama_flash_attention_forward,
+    get_llama_model_forward_for_flash_attn,
     get_lm_forward_with_dist_cross_entropy,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -135,6 +136,15 @@ class LlamaPolicy(Policy):
                 policy=policy,
                 target_key=LlamaAttention,
             )
+            if self.pipeline_stage_manager is None:
+                # replace llama model forward method
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_llama_model_forward_for_flash_attn(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=LlamaModel,
+                )
 
         return policy
 
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index a542808ba..9a74da0b8 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -9,7 +9,12 @@ from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col
 
 from .._utils import getattr_
 from ..modeling.jit import get_jit_fused_dropout_add_func
-from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
+from ..modeling.opt import (
+    OPTPipelineForwards,
+    get_jit_fused_opt_decoder_layer_forward,
+    get_opt_decoder_forward_for_flash_attention,
+    get_opt_flash_attention_forward,
+)
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = [
@@ -27,6 +32,7 @@ class OPTPolicy(Policy):
         import transformers
         from packaging.version import Version
 
+        # TODO: remove this version check when transformers>=4.36.0
         assert Version(transformers.__version__) <= Version(
             "4.33.0"
         ), "The OPT model should run on a transformers version not greater than 4.33.0."
@@ -111,7 +117,9 @@ class OPTPolicy(Policy):
         # optimization configuration
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(
-                suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                suffix="final_layer_norm",
+                target_module=norm_cls,
+                ignore_if_not_exist=True,
             ),
             policy=policy,
             target_key=OPTDecoder,
@@ -119,10 +127,14 @@ class OPTPolicy(Policy):
         self.append_or_create_submodule_replacement(
             description=[
                 SubModuleReplacementDescription(
-                    suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                    suffix="self_attn_layer_norm",
+                    target_module=norm_cls,
+                    ignore_if_not_exist=True,
                 ),
                 SubModuleReplacementDescription(
-                    suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+                    suffix="final_layer_norm",
+                    target_module=norm_cls,
+                    ignore_if_not_exist=True,
                 ),
             ],
             policy=policy,
@@ -133,11 +145,19 @@ class OPTPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_opt_flash_attention_forward(),
+                    "forward": get_opt_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=OPTAttention,
             )
+            if not self.shard_config.pipeline_stage_manager:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_opt_decoder_forward_for_flash_attention(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=OPTDecoder,
+                )
 
         # use jit fused operator
         if self.shard_config.enable_jit_fused:
@@ -190,7 +210,14 @@ class OPTPolicy(Policy):
 
             layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+            method_replacement = {
+                "forward": partial(
+                    new_forward,
+                    stage_manager=stage_manager,
+                    stage_index=stage_index,
+                    shard_config=self.shard_config,
+                )
+            }
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=model_cls
             )
@@ -203,7 +230,9 @@ class OPTModelPolicy(OPTPolicy):
         policy = super().module_policy()
         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=OPTModel, new_forward=OPTPipelineForwards.opt_model_forward, policy=policy
+                model_cls=OPTModel,
+                new_forward=OPTPipelineForwards.opt_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -223,14 +252,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                    suffix="lm_head",
+                    target_module=Linear1D_Col,
+                    kwargs=dict(gather_output=True),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
             )
         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=OPTForCausalLM, new_forward=OPTPipelineForwards.opt_for_causal_lm_forward, policy=policy
+                model_cls=OPTForCausalLM,
+                new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
+                policy=policy,
             )
 
         return policy
@@ -246,7 +279,12 @@ class OPTForCausalLMPolicy(OPTPolicy):
         if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
             num_stages = self.pipeline_stage_manager.num_stages
             if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
-                return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
+                return [
+                    {
+                        0: opt_model.model.decoder.embed_tokens.weight,
+                        num_stages - 1: opt_model.lm_head.weight,
+                    }
+                ]
         return []
 
     def postprocess(self):
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index b5b5db79d..14e1e3e0f 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -13,6 +13,7 @@ from ..modeling.whisper import (
     WhisperPipelineForwards,
     get_jit_fused_whisper_decoder_layer_forward,
     get_jit_fused_whisper_encoder_layer_forward,
+    get_whisper_decoder_forward_for_flash_attention,
     get_whisper_flash_attention_forward,
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -31,6 +32,7 @@ class WhisperPolicy(Policy):
         import transformers
         from packaging.version import Version
 
+        # TODO: remove this version check when transformers>=4.36.0
         assert Version(transformers.__version__) <= Version(
             "4.33.0"
         ), "The Whisper model should run on a transformers version not greater than 4.33.0."
@@ -240,6 +242,14 @@ class WhisperPolicy(Policy):
                 policy=policy,
                 target_key=WhisperAttention,
             )
+            if not self.shard_config.pipeline_stage_manager:
+                self.append_or_create_method_replacement(
+                    description={
+                        "forward": get_whisper_decoder_forward_for_flash_attention(self.shard_config),
+                    },
+                    policy=policy,
+                    target_key=WhisperDecoder,
+                )
 
         # use jit fused operator
         if self.shard_config.enable_jit_fused:
@@ -269,7 +279,9 @@ class WhisperPolicy(Policy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                    suffix="proj_out",
+                    target_module=col_nn.Linear1D_Col,
+                    kwargs={"gather_output": True},
                 ),
                 policy=base_policy,
                 target_key=WhisperForConditionalGeneration,
@@ -326,7 +338,10 @@ class WhisperPolicy(Policy):
         if stage < decoder_starting_stage:
             return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return Policy.get_stage_index(
+                layers_per_stage[decoder_starting_stage:],
+                stage - decoder_starting_stage,
+            )
 
     def get_held_layers(self) -> List[nn.Module]:
         assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
@@ -422,6 +437,7 @@ class WhisperPolicy(Policy):
                 stage_manager=stage_manager,
                 stage_index=stage_index,
                 decoder_starting_stage=decoder_starting_stage,
+                shard_config=self.shard_config,
             )
         }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -436,7 +452,9 @@ class WhisperModelPolicy(WhisperPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy
+                model_cls=WhisperModel,
+                new_forward=WhisperPipelineForwards.whisper_model_forward,
+                policy=policy,
             )
 
         return policy
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 4f2a4878e..e415b5fc3 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -40,7 +40,12 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
         assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"
 
 
-def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True, ignore_dtype: bool = False):
+def check_state_dict_equal(
+    d1: OrderedDict,
+    d2: OrderedDict,
+    ignore_device: bool = True,
+    ignore_dtype: bool = False,
+):
     assert len(list(d1.keys())) == len(
         list(d2.keys())
     ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -94,7 +99,12 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic
 
 
 def assert_hf_output_close(
-    out1: Any, out2: Any, ignore_keys: List[str] = None, track_name: str = "", atol=1e-5, rtol=1e-5
+    out1: Any,
+    out2: Any,
+    ignore_keys: List[str] = None,
+    track_name: str = "",
+    atol=1e-5,
+    rtol=1e-5,
 ):
     """
     Check if two outputs from huggingface are equal.
@@ -113,7 +123,12 @@ def assert_hf_output_close(
             if ignore_keys is not None and k in ignore_keys:
                 continue
             assert_hf_output_close(
-                out1[k], out2[k], track_name=f"{track_name}.{k}", ignore_keys=ignore_keys, atol=atol, rtol=rtol
+                out1[k],
+                out2[k],
+                track_name=f"{track_name}.{k}",
+                ignore_keys=ignore_keys,
+                atol=atol,
+                rtol=rtol,
             )
     elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
         # if two values are list
@@ -121,12 +136,17 @@ def assert_hf_output_close(
         assert len(out1) == len(out2)
         for i in range(len(out1)):
             assert_hf_output_close(
-                out1[i], out2[i], track_name=f"{track_name}.{i}", ignore_keys=ignore_keys, atol=atol, rtol=rtol
+                out1[i],
+                out2[i],
+                track_name=f"{track_name}.{i}",
+                ignore_keys=ignore_keys,
+                atol=atol,
+                rtol=rtol,
             )
     elif isinstance(out1, Tensor) and isinstance(out2, Tensor):
         if out1.shape != out2.shape:
             raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
-        assert torch.allclose(
+        assert_close(
             out1, out2, atol=atol, rtol=rtol
         ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
     else:
diff --git a/extensions/README.md b/extensions/README.md
index 6f5feb55c..b9bde7742 100644
--- a/extensions/README.md
+++ b/extensions/README.md
@@ -101,13 +101,13 @@ class MyExtension(_Extension):
         self._support_jit = True
         self.priority = 10
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         """
         Return if the required hardware can be found.
         """
         ...
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         """
         Check if the hardware required by the kernel is compatible.
         """
diff --git a/extensions/__init__.py b/extensions/__init__.py
index 9343cadda..0dbadba81 100644
--- a/extensions/__init__.py
+++ b/extensions/__init__.py
@@ -1,9 +1,5 @@
 from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
-from .flash_attention import (
-    FlashAttentionDaoCudaExtension,
-    FlashAttentionNpuExtension,
-    FlashAttentionXformersCudaExtension,
-)
+from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension
 from .layernorm import LayerNormCudaExtension
 from .moe import MoeCudaExtension
 from .optimizer import FusedOptimizerCudaExtension
@@ -18,7 +14,7 @@ ALL_EXTENSIONS = [
     ScaledMaskedSoftmaxCudaExtension,
     ScaledUpperTriangleMaskedSoftmaxCudaExtension,
     FlashAttentionDaoCudaExtension,
-    FlashAttentionXformersCudaExtension,
+    FlashAttentionSdpaCudaExtension,
     FlashAttentionNpuExtension,
 ]
 
@@ -31,6 +27,6 @@ __all__ = [
     "ScaledMaskedSoftmaxCudaExtension",
     "ScaledUpperTriangleMaskedSoftmaxCudaExtension",
     "FlashAttentionDaoCudaExtension",
-    "FlashAttentionXformersCudaExtension",
+    "FlashAttentionSdpaCudaExtension",
     "FlashAttentionNpuExtension",
 ]
diff --git a/extensions/base_extension.py b/extensions/base_extension.py
index c815a7f2a..0c79c0a9e 100644
--- a/extensions/base_extension.py
+++ b/extensions/base_extension.py
@@ -58,13 +58,13 @@ class _Extension(ABC):
         return cache_directory
 
     @abstractmethod
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         """
         Check if the hardware required by the kernel is available.
         """
 
     @abstractmethod
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         """
         Check if the hardware required by the kernel is compatible.
         """
diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py
index 35bff3b55..61c4f3ed0 100644
--- a/extensions/cpu_adam/cpu_adam_arm.py
+++ b/extensions/cpu_adam/cpu_adam_arm.py
@@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension):
     def __init__(self):
         super().__init__(name="cpu_adam_arm")
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         # only arm allowed
         return platform.machine() == "aarch64"
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         arch = platform.machine()
         assert (
             arch == "aarch64"
diff --git a/extensions/cpu_adam/cpu_adam_x86.py b/extensions/cpu_adam/cpu_adam_x86.py
index a38194167..9bbc8d851 100644
--- a/extensions/cpu_adam/cpu_adam_x86.py
+++ b/extensions/cpu_adam/cpu_adam_x86.py
@@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension):
     def __init__(self):
         super().__init__(name="cpu_adam_x86")
 
-    def is_hardware_available(self) -> bool:
-        return platform.machine() == "x86_64" and super().is_hardware_available()
+    def is_available(self) -> bool:
+        return platform.machine() == "x86_64" and super().is_available()
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         arch = platform.machine()
         assert (
             arch == "x86_64"
         ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
-        super().assert_hardware_compatible()
+        super().assert_compatible()
 
     # necessary 4 functions
     def sources_files(self):
diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py
index 842cd9713..f1e0095b2 100644
--- a/extensions/cuda_extension.py
+++ b/extensions/cuda_extension.py
@@ -22,7 +22,7 @@ class _CudaExtension(_CppExtension):
         This function should return a list of nvcc compilation flags for extensions.
         """
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         # cuda extension can only be built if cuda is available
         try:
             import torch
@@ -32,7 +32,7 @@ class _CudaExtension(_CppExtension):
             cuda_available = False
         return cuda_available
 
-    def assert_hardware_compatible(self) -> None:
+    def assert_compatible(self) -> None:
         from torch.utils.cpp_extension import CUDA_HOME
 
         if not CUDA_HOME:
diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py
index 18abb6191..ea5b442aa 100644
--- a/extensions/flash_attention/__init__.py
+++ b/extensions/flash_attention/__init__.py
@@ -1,20 +1,14 @@
 from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
 from .flash_attention_npu import FlashAttentionNpuExtension
-from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension
+from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension
 
 try:
+    # TODO: remove this after updating openmoe example
     import flash_attention  # noqa
 
     HAS_FLASH_ATTN = True
 except:
     HAS_FLASH_ATTN = False
 
-try:
-    import xformers  # noqa
 
-    HAS_MEM_EFF_ATTN = True
-except:
-    HAS_MEM_EFF_ATTN = False
-
-
-__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
+__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"]
diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py
index 1b7f8ac47..a2f2a52f1 100644
--- a/extensions/flash_attention/flash_attention_dao_cuda.py
+++ b/extensions/flash_attention/flash_attention_dao_cuda.py
@@ -5,17 +5,20 @@ class FlashAttentionDaoCudaExtension(_Extension):
     def __init__(self):
         super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         # cuda extension can only be built if cuda is available
         try:
             import torch
 
+            from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func  # noqa
+            from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
+
             cuda_available = torch.cuda.is_available()
         except:
             cuda_available = False
         return cuda_available
 
-    def assert_hardware_compatible(self) -> bool:
+    def assert_compatible(self) -> bool:
         pass
 
     def build_aot(self) -> None:
@@ -29,65 +32,65 @@ class FlashAttentionDaoCudaExtension(_Extension):
         )
 
     def load(self):
-        try:
-            from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-        except ImportError:
-            raise ModuleNotFoundError(
-                (
-                    "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
-                )
-            )
-
         from typing import Optional
 
         import torch
+        from einops import rearrange
+        from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func
+        from flash_attn.bert_padding import index_first_axis, pad_input
+
+        def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
+            return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
 
         def flash_attention(
             q: torch.Tensor,
             k: torch.Tensor,
             v: torch.Tensor,
-            seq_len_info_q: "SeqLenInfo",
-            seq_len_info_kv: "SeqLenInfo",
-            origin_attn_mask: Optional[torch.Tensor] = None,
-            bias: Optional[torch.Tensor] = None,
             dropout_p: float = 0.0,
-            scale: float = None,
-            causal: bool = False,
-            padded: bool = False,
+            scale: Optional[float] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            is_causal: bool = False,
+            cu_seqlens_q: Optional[torch.Tensor] = None,
+            cu_seqlens_kv: Optional[torch.Tensor] = None,
+            max_seqlen_q: Optional[int] = None,
+            max_seqlen_kv: Optional[int] = None,
+            q_indices: Optional[torch.Tensor] = None,
+            kv_indices: Optional[torch.Tensor] = None,
         ):
-            """
-            Arguments:
-                q: (batch, q_seqlen, nheads, headdim)
-                k: (batch, kv_seqlen, nheads, headdim)
-                v: (batch, kv_seqlen, nheads, headdim)
-                batch_size: int.
-                seq_len: int.
-                dropout_p: float. Dropout probability.
-                sm_scale: float. The scaling of QK^T before applying softmax.
-                    Default to 1 / sqrt(headdim).
-                causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-            Return:
-                attn_out: (batch, q_seqlen, nheads, headdim).
-            """
-            # check if the input is in allowed dtypes
-            if padded:
-                if seq_len_info_kv == None:
-                    seq_len_info_kv = seq_len_info_q
-
-                attn_out = flash_attn_varlen_func(
+            # [B, N, S, D] -> [B, S, N, D]
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            b, s_q = q.shape[:2]
+            if cu_seqlens_q is not None:
+                # padded / padded causal
+                # unpad input: [B, S, N, D] -> [T, N, D]
+                q = _unpad_input(q, q_indices)
+                kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
+                attn_output = flash_attn_varlen_kvpacked_func(
+                    q,
+                    kv,
+                    cu_seqlens_q,
+                    cu_seqlens_kv,
+                    max_seqlen_q,
+                    max_seqlen_kv,
+                    dropout_p=dropout_p,
+                    softmax_scale=scale,
+                    causal=is_causal,
+                )
+                # pad output: [T, N, D] -> [B, S, N, D]
+                attn_output = pad_input(attn_output, q_indices, b, s_q)
+            else:
+                # causal / no attn mask
+                attn_output = flash_attn_func(
                     q,
                     k,
                     v,
-                    seq_len_info_q.cu_seqlens,
-                    seq_len_info_kv.cu_seqlens,
-                    seq_len_info_q.max_seqlen,
-                    seq_len_info_kv.max_seqlen,
-                    dropout_p,
-                    scale,
-                    causal,
+                    dropout_p=dropout_p,
+                    softmax_scale=scale,
+                    causal=is_causal,
                 )
-            else:
-                attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
-            return attn_out
+            # [B, S, N, D] -> [B, N, S, D]
+            return attn_output.transpose(1, 2)
 
         return flash_attention
diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py
index 58d0f9306..0e01cefa1 100644
--- a/extensions/flash_attention/flash_attention_npu.py
+++ b/extensions/flash_attention/flash_attention_npu.py
@@ -5,15 +5,15 @@ class FlashAttentionNpuExtension(_Extension):
     def __init__(self):
         super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)
 
-    def is_hardware_available(self) -> bool:
+    def is_available(self) -> bool:
         try:
-            import torch_npu  # noqa
+            import torch_npu
 
-            return True
+            return hasattr(torch_npu, "npu_fusion_attention")
         except:
             return False
 
-    def assert_hardware_compatible(self) -> bool:
+    def assert_compatible(self) -> bool:
         pass
 
     def build_aot(self) -> None:
@@ -27,47 +27,36 @@ class FlashAttentionNpuExtension(_Extension):
         )
 
     def load(self):
-        import torch
-        from einops import rearrange
+        from typing import Optional
 
-        def npu_sdpa_attention(
+        import torch
+        import torch_npu
+
+        def flash_attention(
             q: torch.Tensor,
             k: torch.Tensor,
             v: torch.Tensor,
-            seq_len_info_q=None,
-            seq_len_info_kv=None,
-            origin_attn_mask: torch.Tensor = None,
             dropout_p: float = 0.0,
-            scale: float = 1.0,
-            causal=None,
-            padded=None,
+            scale: Optional[float] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            is_causal: bool = False,
+            cu_seqlens_q: Optional[torch.Tensor] = None,
+            cu_seqlens_kv: Optional[torch.Tensor] = None,
+            max_seqlen_q: Optional[int] = None,
+            max_seqlen_kv: Optional[int] = None,
+            q_indices: Optional[torch.Tensor] = None,
+            kv_indices: Optional[torch.Tensor] = None,
         ):
-            """
-            The scaled dot product attention.
-
-            Arguments:
-                q: (batch, q_seqlen, nheads, headdim)
-                k: (batch, kv_seqlen, nheads, headdim)
-                v: (batch, kv_seqlen, nheads, headdim)
-                batch_size: int.
-                seq_len: int.
-                dropout_p: float. Dropout probability.
-                scale: float. The scaling of QK^T before applying softmax.
-                    Default to 1.
-            Return:
-                attn_out: (batch, q_seqlen, nheads, headdim).
-            """
-            q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
-            output = torch.nn.functional.scaled_dot_product_attention(
+            num_heads = q.size(1)
+            return torch_npu.npu_fusion_attention(
                 q,
                 k,
                 v,
-                attn_mask=origin_attn_mask,
-                dropout_p=dropout_p,
-                is_causal=origin_attn_mask is None,
+                num_heads,
+                "BNSD",
+                atten_mask=attention_mask.bool(),
                 scale=scale,
-            )
-            output = rearrange(output, "b h s d -> b s (h d)")
-            return output
+                keep_prob=1 - dropout_p,
+            )[0]
 
-        return npu_sdpa_attention
+        return flash_attention
diff --git a/extensions/flash_attention/flash_attention_sdpa_cuda.py b/extensions/flash_attention/flash_attention_sdpa_cuda.py
new file mode 100644
index 000000000..d3323a6aa
--- /dev/null
+++ b/extensions/flash_attention/flash_attention_sdpa_cuda.py
@@ -0,0 +1,56 @@
+from ..base_extension import _Extension
+
+
+class FlashAttentionSdpaCudaExtension(_Extension):
+    def __init__(self):
+        super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False)
+
+    def is_available(self) -> bool:
+        # cuda extension can only be built if cuda is available
+        try:
+            import torch
+
+            cuda_available = torch.cuda.is_available()
+        except:
+            cuda_available = False
+        return cuda_available
+
+    def assert_compatible(self) -> bool:
+        pass
+
+    def build_aot(self) -> None:
+        raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.")
+
+    def build_jit(self) -> None:
+        raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.")
+
+    def load(self):
+        from typing import Optional
+
+        import torch
+
+        def flash_attention(
+            q: torch.Tensor,
+            k: torch.Tensor,
+            v: torch.Tensor,
+            dropout_p: float = 0.0,
+            scale: Optional[float] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            is_causal: bool = False,
+            cu_seqlens_q: Optional[torch.Tensor] = None,
+            cu_seqlens_kv: Optional[torch.Tensor] = None,
+            max_seqlen_q: Optional[int] = None,
+            max_seqlen_kv: Optional[int] = None,
+            q_indices: Optional[torch.Tensor] = None,
+            kv_indices: Optional[torch.Tensor] = None,
+        ):
+            return torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=dropout_p,
+                scale=scale,
+            )
+
+        return flash_attention
diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py
deleted file mode 100644
index 27cd823de..000000000
--- a/extensions/flash_attention/flash_attention_xformers_cuda.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from ..base_extension import _Extension
-
-
-class FlashAttentionXformersCudaExtension(_Extension):
-    def __init__(self):
-        super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)
-
-    def is_hardware_available(self) -> bool:
-        # cuda extension can only be built if cuda is available
-        try:
-            import torch
-
-            cuda_available = torch.cuda.is_available()
-        except:
-            cuda_available = False
-        return cuda_available
-
-    def assert_hardware_compatible(self) -> bool:
-        pass
-
-    def build_aot(self) -> None:
-        raise NotImplementedError(
-            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
-        )
-
-    def build_jit(self) -> None:
-        raise NotImplementedError(
-            "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
-        )
-
-    def load(self):
-        try:
-            from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
-            from xformers.ops.fmha.attn_bias import (
-                BlockDiagonalCausalMask,
-                BlockDiagonalMask,
-                LowerTriangularMask,
-                LowerTriangularMaskWithTensorBias,
-            )
-        except ImportError:
-            raise ModuleNotFoundError(
-                (
-                    "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
-                )
-            )
-        from typing import Optional
-
-        import torch
-
-        allow_alibi = True
-        for op in MemoryEfficientAttentionCutlassOp:
-            allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
-
-        def mem_eff_attention(
-            q: torch.Tensor,
-            k: torch.Tensor,
-            v: torch.Tensor,
-            seq_len_info_q: "SeqLenInfo",
-            seq_len_info_kv: "SeqLenInfo",
-            origin_attn_mask: Optional[torch.Tensor] = None,
-            bias: Optional[torch.Tensor] = None,
-            dropout_p: float = 0.0,
-            scale: float = None,
-            causal: bool = False,
-            padded: bool = False,
-        ):
-            attn_bias = None
-            if padded:  # bert style
-                if not causal:
-                    attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
-                else:
-                    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
-            elif causal:  # gpt style
-                attn_bias = LowerTriangularMask()
-
-            if bias is not None:  # alibi / relative position embedding
-                assert allow_alibi, "flash attention with bias is not supported in this system."
-                assert causal, "attention with bias is only supported for causal attention so far."
-                attn_bias = attn_bias.add_bias(bias)
-
-            if padded:
-                q = q.unsqueeze(0)
-                k = k.unsqueeze(0)
-                v = v.unsqueeze(0)
-
-            out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
-
-            # shape: (b*s, n, d)
-            if padded:
-                out = out.squeeze(0)
-
-            return out
-
-        return mem_eff_attention
diff --git a/setup.py b/setup.py
index ef89481e6..c16709ad1 100644
--- a/setup.py
+++ b/setup.py
@@ -80,8 +80,8 @@ if BUILD_EXT:
 
     for ext_cls in ALL_EXTENSIONS:
         ext = ext_cls()
-        if ext.support_aot and ext.is_hardware_available():
-            ext.assert_hardware_compatible()
+        if ext.support_aot and ext.is_available():
+            ext.assert_compatible()
             op_names.append(ext.name)
             ext_modules.append(ext.build_aot())
 
diff --git a/tests/test_shardformer/test_flash_attention.py b/tests/test_shardformer/test_flash_attention.py
new file mode 100644
index 000000000..f9eab132f
--- /dev/null
+++ b/tests/test_shardformer/test_flash_attention.py
@@ -0,0 +1,147 @@
+import math
+from copy import copy
+
+import torch
+from torch.testing import assert_close
+
+from colossalai.kernel.kernel_loader import (
+    FlashAttentionLoader,
+    FlashAttentionWithCustomMaskLoader,
+    FlashAttentionWithPaddingMaskLoader,
+)
+from colossalai.shardformer.layer import AttnMaskType, ColoAttention
+from colossalai.shardformer.layer.attn import invert_mask
+from colossalai.testing import clear_cache_before_run, parameterize
+from colossalai.utils import get_current_device, set_seed
+
+DTYPE = [torch.float16, torch.bfloat16]
+B, N, S, D = 2, 8, 256, 32
+
+TOL_MAP = {
+    torch.float16: {"atol": 5e-4, "rtol": 2e-3},
+    torch.bfloat16: {},
+}
+
+
+def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0):
+    head_dim = q.size(-1)
+    attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
+    if attn_mask is not None:
+        attn_weights = attn_weights + attn_mask
+    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype)
+    attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True)
+    attn_output = torch.matmul(attn_weights, v)
+    return attn_output
+
+
+def gen_padded_kwargs(dtype: torch.dtype):
+    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
+    padding_mask[0, : S // 4] = 0
+    return (
+        ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask),
+        padding_mask,
+    )
+
+
+def gen_padded_causal_kwargs(dtype: torch.dtype):
+    padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
+    padding_mask[0, S // 2 :] = 0
+    return (
+        ColoAttention.prepare_attn_kwargs(
+            (B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True
+        ),
+        padding_mask,
+    )
+
+
+def gen_causal_kwargs(dtype: torch.dtype):
+    return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None
+
+
+def gen_custom_kwargs(dtype: torch.dtype):
+    attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device())
+    attn_mask[0, : S // 2, S // 2 :] = 0
+    attn_mask[0, S // 2 :, : S // 2] = 0
+    attn_mask[1, :, S // 4 :] = 0
+    attn_mask = invert_mask(attn_mask).unsqueeze(1)
+    assert not torch.all(attn_mask != 0, dim=-1).any()
+    return {"attention_mask": attn_mask}, None
+
+
+def post_process_kwargs_for_raw_attn(attn_kwargs: dict):
+    if "attention_mask_type" in attn_kwargs:
+        attn_kwargs = copy(attn_kwargs)
+        mask_type = attn_kwargs.pop("attention_mask_type")
+        attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
+    return attn_kwargs
+
+
+def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None):
+    tols = TOL_MAP[dtype]
+    q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
+    k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
+    v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
+    q_flash = q.clone().detach().requires_grad_(True)
+    k_flash = k.clone().detach().requires_grad_(True)
+    v_flash = v.clone().detach().requires_grad_(True)
+    attn_mask = attn_kwargs.get("attention_mask", None)
+    ref_output = attention_ref(q, k, v, attn_mask)
+    output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs)
+    if padding_mask is not None:
+        # [B, Sq] -> [B, 1, Sq, 1]
+        padding_mask = padding_mask[:, None, :, None].logical_not()
+        ref_output = ref_output.masked_fill(padding_mask, 0)
+        output = output.masked_fill(padding_mask, 0)
+    assert_close(output, ref_output, **tols)
+    output.mean().backward()
+    ref_output.mean().backward()
+    assert_close(q.grad, q_flash.grad, **tols)
+    assert_close(k.grad, k_flash.grad, **tols)
+    assert_close(v.grad, v_flash.grad, **tols)
+
+
+@clear_cache_before_run()
+@parameterize("dtype", DTYPE)
+def test_flash_attn_func(dtype: torch.dtype):
+    torch.backends.cudnn.deterministic = True
+    set_seed(0)
+    # (func, name, need_postprocess)
+    avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
+    avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
+    avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
+    for ext_cls in FlashAttentionLoader.REGISTRY:
+        ext = ext_cls()
+        if ext.is_available():
+            ext.assert_compatible()
+            avail_attn_funcs.append((ext.load(), ext.name, True))
+    for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY:
+        ext = ext_cls()
+        if ext.is_available():
+            ext.assert_compatible()
+            avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
+    for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
+        ext = ext_cls()
+        if ext.is_available():
+            ext.assert_compatible()
+            avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
+
+    test_sets = {
+        "none": (lambda dtype: ({}, None), avail_attn_funcs),
+        "padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs),
+        "padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs),
+        "causal": (gen_causal_kwargs, avail_attn_funcs),
+        "custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs),
+    }
+
+    for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items():
+        attn_kwargs, padding_mask = gen_kwargs_func(dtype)
+        for attn_func, name, need_postprocess in attn_funcs:
+            print(f"{dtype}, {name}, {mask_type}")
+            if need_postprocess:
+                check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
+            else:
+                check_attn_func(dtype, attn_func, attn_kwargs, padding_mask)
+
+
+if __name__ == "__main__":
+    test_flash_attn_func()
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index 62d4d1bf3..85be9a242 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -31,6 +31,7 @@ def build_model(
     enable_jit_fused=False,
     enable_sequence_parallelism=False,
     use_lazy_init: bool = False,
+    dtype=torch.float32,
 ):
     # create new model
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -51,7 +52,7 @@ def build_model(
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
-    return org_model.cuda(), sharded_model.cuda()
+    return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype)
 
 
 def build_pipeline_model(
@@ -132,7 +133,14 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
     booster = Booster(plugin=plugin)
 
     sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
-    return org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster
+    return (
+        org_model,
+        org_optimizer,
+        sharded_model,
+        sharded_optimizer,
+        criterion,
+        booster,
+    )
 
 
 def run_forward_backward_with_hybrid_plugin(
@@ -173,7 +181,12 @@ def run_forward_backward_with_hybrid_plugin(
 
         data_iter = iter([data])
         sharded_output = booster.execute_pipeline(
-            data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
+            data_iter,
+            sharded_model,
+            _criterion,
+            sharded_optimizer,
+            return_loss=True,
+            return_outputs=True,
         )
         sharded_loss = sharded_output["loss"]
     else:
@@ -313,7 +326,9 @@ def check_grad(
 
 
 def unwrap_model(
-    module: Module, base_model_class_name: Optional[str] = None, base_model_attribute_name: Optional[str] = None
+    module: Module,
+    base_model_class_name: Optional[str] = None,
+    base_model_attribute_name: Optional[str] = None,
 ):
     if isinstance(module, HybridParallelModule):
         module = module.unwrap()
diff --git a/tests/test_shardformer/test_model/test_shard_blip2.py b/tests/test_shardformer/test_model/test_shard_blip2.py
index 02c15460e..2c56b0435 100644
--- a/tests/test_shardformer/test_model/test_shard_blip2.py
+++ b/tests/test_shardformer/test_model/test_shard_blip2.py
@@ -45,19 +45,51 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
         "qformer.encoder.layer[0].attention.output.dense",
         "language_model.model.decoder.layers[0].self_attn.out_proj",
     ]
-    check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
-    check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
+    check_grad(
+        blip2,
+        sharded_blip2,
+        col_layer_for_check,
+        atol=1e-6,
+        rtol=1e-5,
+        dim=0,
+        verbose=False,
+    )
+    check_grad(
+        blip2,
+        sharded_blip2,
+        row_layer_for_check,
+        atol=1e-6,
+        rtol=1e-5,
+        dim=1,
+        verbose=False,
+    )
 
 
 @parameterize("enable_fused_normalization", [True, False])
 @parameterize("enable_tensor_parallelism", [True, False])
 @parameterize("enable_flash_attention", [True, False])
 @parameterize("enable_jit_fused", [True, False])
-def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
+def run_blip2_test(
+    enable_fused_normalization,
+    enable_tensor_parallelism,
+    enable_flash_attention,
+    enable_jit_fused,
+):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2")
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         org_model, sharded_model = build_model(
-            model_fn, enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused
+            model_fn,
+            enable_fused_normalization,
+            enable_tensor_parallelism,
+            enable_flash_attention,
+            enable_jit_fused,
+            dtype=torch.float,
         )
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
@@ -66,7 +98,14 @@ def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable
 
 def check_blip2(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_blip2_test()
 
 
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
index 29d3592bf..78d752b69 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm2.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -11,7 +11,6 @@ from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
     check_all_grad_tensors,
     check_loss,
-    check_output_hidden_state,
     check_weight,
     get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
@@ -25,7 +24,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        org_model,
+        sharded_model,
+        sharded_optimizer,
+        data_gen_fn,
+        output_transform_fn,
+        criterion,
+        booster,
     )
 
     stage_manager = booster.plugin.stage_manager
@@ -36,7 +41,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")
 
     norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
-    row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"]
+    row_layer_for_check = [
+        "encoder.layers[0].self_attention.query_key_value",
+        "embedding.word_embeddings",
+    ]
     col_layer_for_check = ["encoder.layers[0].self_attention.dense"]
 
     # Save gradient tensors for comparison between the original model and the sharded model.
@@ -94,8 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 5e-3, 5e-3
 
-        if org_model.__class__.__name__ == "ChatGLMModel":
-            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+        # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong
+        # if org_model.__class__.__name__ == "ChatGLMModel":
+        #     check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
 
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
@@ -143,8 +152,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
         },
-        {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
-        {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": True,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
         {
             "tp_size": 2,
             "pp_size": 1,
@@ -159,7 +180,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 def run_chatglm_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -193,7 +220,13 @@ def run_chatglm_test(test_config):
 def run_chatglm_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -202,13 +235,27 @@ def run_chatglm_3d_test(test_config):
 
 def check_chatglm(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_chatglm_test()
 
 
 def check_chatglm_3d(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_chatglm_3d_test()
 
 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 3155420f1..d59d7e4ad 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        org_model,
+        sharded_model,
+        sharded_optimizer,
+        data_gen_fn,
+        output_transform_fn,
+        criterion,
+        booster,
     )
 
     stage_manager = booster.plugin.stage_manager
@@ -47,10 +53,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 5e-3, 5e-3
         col_layer_grads = get_grad_tensors_for_check(
-            gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+            gpt2,
+            sharded_gpt2,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
         )
         row_layer_grads = get_grad_tensors_for_check(
-            gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+            gpt2,
+            sharded_gpt2,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
         )
 
         norm_layer_grads = get_grad_tensors_for_check(
@@ -90,7 +110,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+        check_weight(
+            gpt2,
+            sharded_gpt2,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
+        )
 
     # check grads
     check_all_grad_tensors(grads_to_check)
@@ -123,14 +152,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
@@ -138,7 +167,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": True,
             "precision": "fp32",
         },
@@ -167,7 +196,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 def run_gpt2_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -202,7 +237,13 @@ def run_gpt2_test(test_config):
 def run_gpt2_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -211,13 +252,27 @@ def run_gpt2_3d_test(test_config):
 
 def check_gpt2(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_gpt2_test()
 
 
 def check_gpt2_3d(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_gpt2_3d_test()
 
 
diff --git a/tests/test_shardformer/test_model/test_shard_gptj.py b/tests/test_shardformer/test_model/test_shard_gptj.py
index c83eaaa09..009202a0d 100644
--- a/tests/test_shardformer/test_model/test_shard_gptj.py
+++ b/tests/test_shardformer/test_model/test_shard_gptj.py
@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        org_model,
+        sharded_model,
+        sharded_optimizer,
+        data_gen_fn,
+        output_transform_fn,
+        criterion,
+        booster,
     )
 
     stage_manager = booster.plugin.stage_manager
@@ -46,11 +52,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 5e-3, 5e-3
         col_layer_grads = get_grad_tensors_for_check(
-            gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+            gptj,
+            sharded_gptj,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
         )
 
         row_layer_grads = get_grad_tensors_for_check(
-            gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+            gptj,
+            sharded_gptj,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
         )
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
@@ -77,7 +97,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             atol, rtol = 5e-3, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
-        check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
+        check_weight(
+            gptj,
+            sharded_gptj,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
+        )
 
     # check grads
     check_all_grad_tensors(grads_to_check)
@@ -110,14 +139,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
@@ -125,7 +154,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             #'use_lazy_init': True,
             "precision": "fp32",
         },
@@ -154,7 +183,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 def run_gptj_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -189,7 +224,13 @@ def run_gptj_test(test_config):
 def run_gptj_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -198,15 +239,30 @@ def run_gptj_3d_test(test_config):
 
 def check_gptj(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_gptj_test()
 
 
 def check_gptj_3d(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_gptj_3d_test()
 
+
 @pytest.mark.skip("TODO check_gptj has something wrong.")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index c7edcfb35..126ff23a9 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -112,7 +112,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
@@ -124,7 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
         },
-        {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
+        {"tp_size": 2, "pp_size": 1, "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32"},
         {
             "tp_size": 2,
             "pp_size": 1,
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index d21ab264d..523ed879b 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -29,7 +29,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        org_model,
+        sharded_model,
+        sharded_optimizer,
+        data_gen_fn,
+        output_transform_fn,
+        criterion,
+        booster,
     )
 
     stage_manager = booster.plugin.stage_manager
@@ -39,7 +45,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     opt_model = unwrap_model(org_model, "OPTModel", "model")
     shard_opt_model = unwrap_model(sharded_model, "OPTModel", "model")
 
-    row_layer_for_check = ["decoder.layers[0].self_attn.q_proj", "decoder.embed_tokens"]  # 'decoder.embed_tokens'
+    row_layer_for_check = [
+        "decoder.layers[0].self_attn.q_proj",
+        "decoder.embed_tokens",
+    ]  # 'decoder.embed_tokens'
     col_layer_for_check = ["decoder.layers[0].self_attn.out_proj"]
 
     # Save gradient tensors for comparison between the original model and the sharded model.
@@ -50,10 +59,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 4e-2, 4e-2
         row_layer_grads = get_grad_tensors_for_check(
-            opt_model, shard_opt_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+            opt_model,
+            shard_opt_model,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
         )
         col_layer_grads = get_grad_tensors_for_check(
-            opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+            opt_model,
+            shard_opt_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
         )
         grads_to_check.update(col_layer_grads)
         grads_to_check.update(row_layer_grads)
@@ -80,7 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         else:
             atol, rtol = 5e-3, 5e-3
         check_weight(
-            opt_model, shard_opt_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+            opt_model,
+            shard_opt_model,
+            col_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=1,
+            verbose=False,
         )
 
     # check grads
@@ -110,8 +140,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
         },
-        {"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
-        {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
+        {
+            "tp_size": 4,
+            "pp_size": 1,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
+        {
+            "tp_size": 2,
+            "pp_size": 1,
+            "enable_all_optimization": False,
+            "use_lazy_init": False,
+            "precision": "fp32",
+        },
         {
             "tp_size": 2,
             "pp_size": 1,
@@ -135,7 +177,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 def run_opt_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_opt")
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -169,7 +217,13 @@ def run_opt_test(test_config):
 def run_opt_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_opt")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -178,13 +232,27 @@ def run_opt_3d_test(test_config):
 
 def check_OPTModel(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_opt_test()
 
 
 def check_opt_3d(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_opt_3d_test()
 
 
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 22c201458..9b22d54d7 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        org_model,
+        sharded_model,
+        sharded_optimizer,
+        data_gen_fn,
+        output_transform_fn,
+        criterion,
+        booster,
     )
 
     stage_manager = booster.plugin.stage_manager
@@ -71,7 +77,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
+        check_weight(
+            t5,
+            sharded_t5,
+            row_layer_for_check,
+            tp_group,
+            atol=atol,
+            rtol=rtol,
+            dim=0,
+            verbose=False,
+        )
 
     # check grads
     check_all_grad_tensors(grads_to_check)
@@ -104,7 +119,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
@@ -117,7 +132,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
         },
-        {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
         {
             "tp_size": 2,
             "pp_size": 1,
@@ -144,7 +158,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 def run_t5_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         # skip 4-stage pp test for t5_encoder
         if test_config["pp_size"] > 2 and name == "transformers_t5_encoder_model":
             continue
@@ -185,7 +205,13 @@ def run_t5_test(test_config):
 def run_t5_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_t5")
 
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+    for name, (
+        model_fn,
+        data_gen_fn,
+        output_transform_fn,
+        loss_fn,
+        _,
+    ) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
     clear_layout_converter()
@@ -194,13 +220,27 @@ def run_t5_3d_test(test_config):
 
 def check_t5(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_t5_test()
 
 
 def check_t5_3d(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    colossalai.launch(
+        config={},
+        rank=rank,
+        world_size=world_size,
+        host="localhost",
+        port=port,
+        backend="nccl",
+    )
     run_t5_3d_test()
 
 
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
deleted file mode 100644
index 3ec170004..000000000
--- a/tests/test_utils/test_flash_attention.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import math
-
-import pytest
-import torch
-from einops import rearrange
-
-from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN
-from colossalai.testing import clear_cache_before_run, parameterize
-
-if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
-
-DTYPE = [torch.float16, torch.bfloat16, torch.float32]
-
-
-def attention_ref(q, k, v, attn_mask=None, causal=False):
-    """
-    attention output of the control group
-    """
-    dtype_og = q.dtype
-    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
-    d = q.shape[-1]
-    scale = 1.0 / math.sqrt(d)
-    scores = torch.einsum("bthd,bshd->bhts", q * scale, k)
-
-    if attn_mask is not None:
-        scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
-    if causal:
-        causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1)
-        scores.masked_fill_(causal_mask, float("-inf"))
-    attention = torch.softmax(scores, dim=-1)
-
-    output = torch.einsum("bhts,bshd->bthd", attention, v)
-    output = rearrange(output, "b s h d -> b s (h d)")
-
-    # Modify the data at the positions of the mask to 0
-    if attn_mask is not None:
-        output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1"), 0.0)
-
-    return output.to(dtype=dtype_og)
-
-
-@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
-@clear_cache_before_run()
-@parameterize("proj_shape", [(6, 8, 4, 16)])
-@parameterize("dtype", DTYPE)
-@parameterize("dropout", [0.0])
-def test_attention_gpt(proj_shape, dtype, dropout):
-    (B, S, H, D_HEAD) = proj_shape
-    D = H * D_HEAD
-
-    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-
-    mask = [torch.ones(S - i, dtype=torch.bool, device="cuda") for i in range(B)]
-    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
-
-    attn = ColoAttention(D, H, dropout=dropout)
-    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
-
-    assert list(y.shape) == [B, S, D]
-
-    out_ref = attention_ref(q, k, v, mask, causal=True)
-
-    # check gradients
-    dy = torch.rand_like(y)
-    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
-    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
-
-    torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
-    torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
-    torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
-    torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
-
-
-@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
-@clear_cache_before_run()
-@parameterize("proj_shape", [(6, 8, 4, 16)])
-@parameterize("dtype", DTYPE)
-@parameterize("dropout", [0.0])
-def test_attention_bert(proj_shape, dtype, dropout):
-    (B, S, H, D_HEAD) = proj_shape
-    D = H * D_HEAD
-
-    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-
-    # attention mask of shape [B, S] with zero padding to max length S
-    mask = torch.randint(0, 2, (B, S), dtype=torch.bool, device="cuda")
-
-    attn = ColoAttention(D, H, dropout=dropout)
-    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)
-
-    assert list(y.shape) == [B, S, D]
-
-    out_ref = attention_ref(q, k, v, mask, causal=False)
-
-    dy = torch.rand_like(y)
-    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
-    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
-
-    torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
-    torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
-    torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
-    torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
-
-
-@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
-@clear_cache_before_run()
-@parameterize("proj_shape", [(6, 8, 4, 16)])
-@parameterize("dtype", DTYPE)
-@parameterize("dropout", [0.0])
-def test_attention_no_mask(proj_shape, dtype, dropout):
-    (B, S, H, D_HEAD) = proj_shape
-    D = H * D_HEAD
-
-    q = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-
-    attn = ColoAttention(D, H, dropout=dropout)
-    y = attn(q, k, v)
-
-    assert list(y.shape) == [B, S, D]
-
-    out_ref = attention_ref(q, k, v, None, causal=False)
-
-    dy = torch.rand_like(y)
-    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
-    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
-
-    torch.allclose(y, out_ref, atol=1e-7), f"{(y - out_ref).abs().max()}"
-    torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
-    torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
-    torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"
-
-
-@pytest.mark.skipif(not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN, reason="xformers is not available")
-@clear_cache_before_run()
-@parameterize("proj_shape", [(6, 24, 8, 4, 16)])
-@parameterize("dtype", DTYPE)
-@parameterize("dropout", [0.0])
-def test_cross_attention(proj_shape, dtype, dropout):
-    (B, S, T, H, D_HEAD) = proj_shape
-    D = H * D_HEAD
-
-    q = torch.randn((B, T, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    k = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    v = torch.randn((B, S, H, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-
-    attn = ColoAttention(D, H, dropout=dropout)
-    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
-
-    assert list(y.shape) == [B, T, D]
-
-    out_ref = attention_ref(q, k, v, None, causal=True)
-
-    dy = torch.rand_like(y)
-    grad_q, grad_k, grad_v = torch.autograd.grad(y, (q, k, v), dy)
-    grad_ref_q, grad_ref_k, grad_ref_v = torch.autograd.grad(out_ref, (q, k, v), dy)
-
-    torch.allclose(y, out_ref, atol=1e-18), f"{(y - out_ref).abs().max()}"
-    torch.allclose(grad_q, grad_ref_q, atol=1e-7), f"{(grad_q - grad_ref_q).abs().max()}"
-    torch.allclose(grad_k, grad_ref_k, atol=1e-7), f"{(grad_k - grad_ref_k).abs().max()}"
-    torch.allclose(grad_v, grad_ref_v, atol=1e-7), f"{(grad_v - grad_ref_v).abs().max()}"