From c54c4fcd15b70830e2efe00df0a6087b9ce5f6b1 Mon Sep 17 00:00:00 2001
From: botbw <wang1570@e.ntu.edu.sg>
Date: Tue, 10 Sep 2024 17:30:53 +0800
Subject: [PATCH] [hotfix] moe hybrid parallelism benchmark & follow-up fix
 (#6048)

* [example] pass use_fp8_comm flag to all plugins

* [example] add mixtral benchmark

* [moe] refine assertion and check

* [moe] fix mixtral & add more tests

* [moe] consider checking dp * sp group and moe_dp_group

* [mixtral] remove gate tp & add more tests

* [deepseek] fix tp & sp for deepseek

* [mixtral] minor fix

* [deepseek] add deepseek benchmark
---
 .../plugin/moe_hybrid_parallel_plugin.py      |  35 ++-
 colossalai/moe/_operation.py                  |   7 +-
 colossalai/shardformer/modeling/deepseek.py   |  81 +++++-
 colossalai/shardformer/modeling/mixtral.py    |   4 +-
 colossalai/shardformer/policies/deepseek.py   |  36 ++-
 colossalai/shardformer/policies/mixtral.py    |  22 +-
 examples/language/deepseek/benchmark.py       | 271 ++++++++++++++++++
 examples/language/deepseek/data_utils.py      |   1 +
 examples/language/deepseek/model_utils.py     |   1 +
 .../deepseek/performance_evaluator.py         |   1 +
 examples/language/deepseek/test_ci.sh         |   0
 examples/language/llama/benchmark.py          |   6 +-
 examples/language/mixtral/benchmark.py        | 259 +++++++++++++++++
 examples/language/mixtral/data_utils.py       |   1 +
 examples/language/mixtral/model_utils.py      |   1 +
 .../language/mixtral/performance_evaluator.py |   1 +
 examples/language/mixtral/test_ci.sh          |   0
 tests/test_moe/moe_utils.py                   |  71 ++++-
 tests/test_moe/test_moe_checkpoint.py         |   4 +-
 .../test_model/test_shard_deepseek.py         | 102 +++++--
 .../test_model/test_shard_mixtral.py          | 102 +++++--
 21 files changed, 907 insertions(+), 99 deletions(-)
 create mode 100644 examples/language/deepseek/benchmark.py
 create mode 120000 examples/language/deepseek/data_utils.py
 create mode 120000 examples/language/deepseek/model_utils.py
 create mode 120000 examples/language/deepseek/performance_evaluator.py
 create mode 100755 examples/language/deepseek/test_ci.sh
 create mode 100644 examples/language/mixtral/benchmark.py
 create mode 120000 examples/language/mixtral/data_utils.py
 create mode 120000 examples/language/mixtral/model_utils.py
 create mode 120000 examples/language/mixtral/performance_evaluator.py
 create mode 100755 examples/language/mixtral/test_ci.sh

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 74d35f5c5..2324a5239 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -64,13 +64,18 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
     ):
-        pg_param_list = {
-            dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
-            moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
-        }
+        if dp_process_group is moe_dp_group:
+            pg_param_list = {
+                dp_process_group: list(model.parameters()),
+            }
+        else:
+            pg_param_list = {
+                dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
+                moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
+            }
 
-        if len(pg_param_list[dp_process_group]) == 0 or len(pg_param_list[moe_dp_group]) == 0:
-            raise ValueError("No parameters found in dp_process_group or moe_dp_group")
+        if len(pg_param_list[moe_dp_group]) == 0:
+            raise ValueError("No parameters found in moe_dp_group, please consider using HybridParallelPlugin instead")
 
         super().__init__(
             model=model,
@@ -407,6 +412,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 and self.enable_sequence_parallelism
                 and self.sequence_parallelism_mode == "all_to_all"
             )
+
+            # sync gradients across DP * SP ranks
+            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+                dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+            else:
+                dp_group = self.dp_group
+
             if use_ddp:
                 self.logger.warning(
                     f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@@ -414,17 +426,11 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 )
                 self.ddp_config["find_unused_parameters"] = True
 
-                if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+                if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
                     raise ValueError(
-                        f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
+                        f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
                     )
 
-            # sync gradients across DP * SP ranks
-            if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-                dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
-            else:
-                dp_group = self.dp_group
-
             model = HybridParallelModule(
                 module=model,
                 precision=self.precision,
@@ -466,6 +472,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                         tp_process_group=self.tp_group,
                     )
             else:
+                is_zero = True
                 if self.dp_size <= 1:
                     self.logger.warning(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
index ba087a03b..62904d90e 100644
--- a/colossalai/moe/_operation.py
+++ b/colossalai/moe/_operation.py
@@ -308,7 +308,7 @@ class EPGradScalerIn(torch.autograd.Function):
         assert len(grad_outputs) == 1
         grad = grad_outputs[0]
         if ctx.ep_size != 1:
-            grad = grad * ctx.ep_size
+            grad.mul_(ctx.ep_size)
         return grad, None
 
 
@@ -328,7 +328,7 @@ class EPGradScalerOut(torch.autograd.Function):
         assert len(grad_outputs) == 1
         grad = grad_outputs[0]
         if ctx.ep_size != 1:
-            grad = grad / ctx.ep_size
+            grad.div_(ctx.ep_size)
         return grad, None
 
 
@@ -449,7 +449,4 @@ def all_to_all_uneven(
     overlap: bool = False,
     fp8_communication: bool = False,
 ):
-    assert (
-        inputs.requires_grad
-    ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
     return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 7ec390d6a..4b1b82b7c 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -3,7 +3,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
+import torch.functional as F
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
@@ -28,11 +28,13 @@ from colossalai.quantization.fp8 import all_reduce_fp8
 from colossalai.shardformer.layer._operation import (
     all_to_all_comm,
     gather_forward_split_backward,
+    linear_with_async_comm,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.d_tensor.api import shard_rowwise, sharded_tensor_to_existing_param
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
 
 
@@ -58,7 +60,7 @@ class AddAuxiliaryLoss(torch.autograd.Function):
         return grad_output, grad_loss
 
 
-class EPDeepseekMoE(nn.Module):
+class EPDeepseekMoE(ParallelModule):
     def __init__(self):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
@@ -214,6 +216,79 @@ class EPDeepseekMoE(nn.Module):
         return output_hidden_states
 
 
+class DeepseekMoEGate_Col(ParallelModule):
+    def parallel_linear(self, hidden_states):
+        assert (
+            hidden_states.shape[-1] == self.weight.shape[-1]
+        ), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
+            hidden_states.shape, self.weight.shape, self.weight.shape[-1]
+        )
+
+        output = linear_with_async_comm(
+            hidden_states, self.weight, None, self.process_group, True, fp8_communication=self.fp8_communication
+        )
+
+        # All-gather across the partitions.
+        output = gather_forward_split_backward(
+            output, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+        )
+        return output
+
+    def forward(self, hidden_states):
+        bsz, seq_len, h = hidden_states.shape
+        ### compute gating score
+        hidden_states = hidden_states.view(-1, h)
+        logits = self.parallel_linear(hidden_states)
+        if self.scoring_func == "softmax":
+            scores = logits.softmax(dim=-1)
+        else:
+            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
+
+        ### select top-k experts
+        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+
+        ### norm gate to sum 1
+        if self.top_k > 1 and self.norm_topk_prob:
+            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weight = topk_weight / denominator
+
+        ### expert-level computation auxiliary loss
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+                ce.scatter_add_(
+                    1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
+                ).div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
+
+        return topk_idx, topk_weight, aux_loss
+
+    @staticmethod
+    def from_native_module(
+        module, process_group: ProcessGroup, config, gather_output, fp8_communication
+    ) -> "DeepseekMoEGate_Col":
+        LazyInitContext.materialize(module)
+        module.process_group = process_group
+        module.fp8_communication = fp8_communication
+        sharded_weight = shard_rowwise(module.weight.data, process_group)
+        sharded_tensor_to_existing_param(sharded_weight, module.weight)
+        module.__class__ = DeepseekMoEGate_Col
+        return module
+
+
 class DeepseekPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Llama models
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 4850ef1b6..0103808dc 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -36,7 +36,7 @@ from colossalai.shardformer.layer._operation import (
     gather_forward_split_backward,
     split_forward_gather_backward,
 )
-from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row, ParallelModule
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
 from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
@@ -49,7 +49,7 @@ if is_flash_attn_2_available():
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
 
-class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
+class EPMixtralSparseMoeBlock(ParallelModule):
     def __init__(self, *args, **kwargs):
         raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 0b8a602d1..bd54e6f2d 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -10,6 +10,7 @@ from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
 from colossalai.shardformer.layer.linear import Linear1D_Row
 from colossalai.shardformer.modeling.deepseek import (
+    DeepseekMoEGate_Col,
     DeepseekPipelineForwards,
     EPDeepseekMoE,
     get_deepseek_flash_attention_forward,
@@ -56,16 +57,24 @@ class DeepseekPolicy(Policy):
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        tp_size = self.shard_config.tensor_parallel_size
+
+        # modified for both SP and TP
+        num_q_heads = self.model.config.num_attention_heads
+        num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
         if sp_mode == "all_to_all":
+            num_q_heads //= sp_size
             decoder_attribute_replacement = {
-                "num_heads": self.model.config.num_attention_heads // sp_size,
+                "num_heads": num_q_heads,
             }
             if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+                num_kv_heads //= sp_size
+                decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
+
         if self.shard_config.enable_sequence_parallelism:
             if self.pipeline_stage_manager is not None:
                 # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism
@@ -97,6 +106,7 @@ class DeepseekPolicy(Policy):
         else:
             if self.tie_weight:
                 embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_tensor_parallelism:
             # tensor parallelism for non-moe params
             assert (
@@ -107,10 +117,15 @@ class DeepseekPolicy(Policy):
             ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
-                // self.shard_config.tensor_parallel_size,
             }
+            num_q_heads //= tp_size
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": num_q_heads,
+            }
+            if num_kv_heads:
+                num_kv_heads //= tp_size
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
 
             policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
@@ -135,8 +150,19 @@ class DeepseekPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.gate",
+                        target_module=DeepseekMoEGate_Col,
+                        kwargs={
+                            "gather_output": True,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                            "config": self.model.config,
+                        },
+                        ignore_if_not_exist=True,
+                    ),
                 ],
             )
+
         if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 3a373889c..9f03319e7 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -51,12 +51,20 @@ class MixtralPolicy(Policy):
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
+        tp_size = self.shard_config.tensor_parallel_size
+
+        # modified for both SP and TP
+        num_q_heads = self.model.config.num_attention_heads
+        num_kv_heads = getattr(self.model.config, "num_key_value_heads", None)
+
         if sp_mode == "all_to_all":
+            num_q_heads //= sp_size
             decoder_attribute_replacement = {
-                "num_heads": self.model.config.num_attention_heads // sp_size,
+                "num_heads": num_q_heads,
             }
             if getattr(self.model.config, "num_key_value_heads", False):
-                decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+                num_kv_heads //= sp_size
+                decoder_attribute_replacement["num_key_value_heads"] = num_kv_heads
 
             policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
@@ -101,12 +109,14 @@ class MixtralPolicy(Policy):
             assert (
                 self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of key_value heads must be divisible by tensor parallel size."
+            num_q_heads //= tp_size
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
-                // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": num_q_heads,
             }
+            if num_kv_heads:
+                num_kv_heads //= tp_size
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = num_kv_heads
 
             policy[MixtralDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
@@ -131,7 +141,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={"fp8_communication": self.shard_config.fp8_communication},
                     ),
-                    SubModuleReplacementDescription(  # or replicate?
+                    SubModuleReplacementDescription(
                         suffix="block_sparse_moe.gate",
                         target_module=Linear1D_Col,
                         kwargs={"gather_output": True, "fp8_communication": self.shard_config.fp8_communication},
diff --git a/examples/language/deepseek/benchmark.py b/examples/language/deepseek/benchmark.py
new file mode 100644
index 000000000..fef181e71
--- /dev/null
+++ b/examples/language/deepseek/benchmark.py
@@ -0,0 +1,271 @@
+# modified from mixtral benchmark
+import argparse
+import resource
+import time
+import warnings
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from data_utils import RandomDataset
+from model_utils import format_numel_str, get_model_numel
+from performance_evaluator import PerformanceEvaluator, get_profile_context
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModelForCausalLM
+
+import colossalai
+from colossalai.accelerator import get_accelerator
+from colossalai.booster import Booster
+from colossalai.booster.plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.shardformer import PipelineGradientCheckpointConfig
+
+warnings.filterwarnings("ignore")
+# ==============================
+# Constants
+# ==============================
+
+# We have lots of llamas for your choice!
+MODEL_CONFIGS = {
+    "100m": lambda: AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-moe-16b-base",
+        max_position_embeddings=4096,
+        num_hidden_layers=1,
+        num_attention_heads=32,
+        intermediate_size=512,
+        moe_intermediate_size=128,
+        hidden_size=512,
+        n_routed_experts=8,
+        n_shared_experts=4,
+        num_experts_per_tok=2,
+        first_k_dense_replace=0,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ),
+    "7b": lambda: AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-moe-16b-base",
+        max_position_embeddings=4096,
+        num_hidden_layers=13,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ),
+    "14b": lambda: AutoConfig.from_pretrained(
+        "deepseek-ai/deepseek-moe-16b-base",
+        max_position_embeddings=4096,
+        num_hidden_layers=26,
+        attn_implementation="flash_attention_2",
+        trust_remote_code=True,
+    ),
+}
+
+
+def main():
+    # ==============================
+    # Parse Arguments
+    # ==============================
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration")
+    parser.add_argument(
+        "-p",
+        "--plugin",
+        choices=["3d"],
+        default="3d",
+        help="Choose which plugin to use",
+    )
+    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
+    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
+    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
+    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
+    parser.add_argument(
+        "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
+    )
+    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
+    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
+    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
+    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
+    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
+    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--ep", type=int, default=1, help="Expert parallel size")
+    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
+    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
+    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
+    parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
+    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
+    parser.add_argument("--profile", action="store_true", help="Profile the code")
+    parser.add_argument(
+        "--nsys",
+        action="store_true",
+        help="Use nsys for profiling. \
+        You should put something like this before colossalai launch: \
+        nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
+    )
+    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
+    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument(
+        "--sp_mode",
+        default="all_to_all",
+        choices=["all_to_all"],
+        help="Sequence parallelism mode",
+    )
+    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+    args = parser.parse_args()
+
+    colossalai.launch_from_torch()
+    coordinator = DistCoordinator()
+
+    # ckpt config for LLaMA3-70B on 64 H100 GPUs
+    hybrid_kwargs = (
+        {
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[19, 19, 19, 13],
+            ),
+            "num_layers_per_stage": [19, 20, 20, 21],
+            "pp_style": "interleaved",
+        }
+        if args.custom_ckpt
+        else {}
+    )
+
+    # ==============================
+    # Initialize Booster
+    # ==============================
+    if args.plugin == "3d":
+        plugin = MoeHybridParallelPlugin(
+            ep_size=args.ep,
+            tp_size=args.tp,
+            pp_size=args.pp,
+            pp_style=args.pp_style,
+            num_model_chunks=args.n_chunks,
+            zero_stage=args.zero,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            enable_sequence_parallelism=args.sp > 1,
+            enable_fused_normalization=torch.cuda.is_available(),
+            enable_flash_attention=args.xformers,
+            microbatch_size=args.mbs,
+            precision="bf16",
+            enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
+            use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
+            **hybrid_kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown plugin {args.plugin}")
+
+    booster = Booster(plugin=plugin)
+
+    # ==============================
+    # Initialize Dataset and Dataloader
+    # ==============================
+    dp_size = getattr(plugin, "dp_size", coordinator.world_size)
+
+    config = MODEL_CONFIGS[args.config]()
+
+    torch.cuda.manual_seed(42)
+
+    dataset = RandomDataset(
+        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
+    )
+    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
+
+    # ==============================
+    # Initialize Model and Optimizer
+    # ==============================
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, MoeHybridParallelPlugin)
+        else nullcontext()
+    )
+
+    with init_ctx:
+        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True).to(torch.bfloat16)
+
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()
+
+    model_numel = get_model_numel(model)
+    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
+    performance_evaluator = PerformanceEvaluator(
+        model_numel,
+        model.config.num_hidden_layers,
+        model.config.hidden_size,
+        model.config.vocab_size,
+        args.grad_checkpoint,
+        args.ignore_steps,
+        dp_world_size=dp_size,
+    )
+
+    optimizer = HybridAdam(model.parameters())
+    torch.set_default_dtype(torch.bfloat16)
+    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
+
+    torch.set_default_dtype(torch.float)
+    coordinator.print_on_master(
+        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
+    )
+    coordinator.print_on_master(
+        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
+    )
+
+    with get_profile_context(
+        args.profile,
+        args.ignore_steps,
+        1,  # avoid creating massive log files
+        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+        nsys=args.nsys,
+    ) as prof:  # , distributed_debug_mode(10, enable=True):
+        if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
+            data_iter = iter(dataloader)
+            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
+                performance_evaluator.on_step_start(step)
+                outputs = booster.execute_pipeline(
+                    data_iter,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
+                optimizer.step()
+                optimizer.zero_grad()
+
+                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                prof.step()
+                print(f"rank {dist.get_rank()} step {step} passed")
+        else:
+            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
+                performance_evaluator.on_step_start(step)
+                outputs = model(**batch)
+                loss = outputs[0]
+                del outputs  # free memory
+
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
+
+                booster.backward(loss, optimizer)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                performance_evaluator.on_step_end(**batch)
+                prof.step()
+
+    performance_evaluator.on_fit_end()
+    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/language/deepseek/data_utils.py b/examples/language/deepseek/data_utils.py
new file mode 120000
index 000000000..2da9822df
--- /dev/null
+++ b/examples/language/deepseek/data_utils.py
@@ -0,0 +1 @@
+../data_utils.py
\ No newline at end of file
diff --git a/examples/language/deepseek/model_utils.py b/examples/language/deepseek/model_utils.py
new file mode 120000
index 000000000..73c6818a8
--- /dev/null
+++ b/examples/language/deepseek/model_utils.py
@@ -0,0 +1 @@
+../model_utils.py
\ No newline at end of file
diff --git a/examples/language/deepseek/performance_evaluator.py b/examples/language/deepseek/performance_evaluator.py
new file mode 120000
index 000000000..f4736354b
--- /dev/null
+++ b/examples/language/deepseek/performance_evaluator.py
@@ -0,0 +1 @@
+../performance_evaluator.py
\ No newline at end of file
diff --git a/examples/language/deepseek/test_ci.sh b/examples/language/deepseek/test_ci.sh
new file mode 100755
index 000000000..e69de29bb
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index bb14378ad..0e88fabf1 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -105,7 +105,7 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
-    parser.add_argument("--use_fp8", action="store_true")
+    parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
     parser.add_argument("--overlap_allgather", action="store_true")
     parser.add_argument(
         "--sp_mode",
@@ -151,6 +151,7 @@ def main():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -164,6 +165,7 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -224,6 +226,7 @@ def main():
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
@@ -241,6 +244,7 @@ def main():
             precision="bf16",
             overlap_p2p=args.overlap,
             use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
diff --git a/examples/language/mixtral/benchmark.py b/examples/language/mixtral/benchmark.py
new file mode 100644
index 000000000..bb2a32d01
--- /dev/null
+++ b/examples/language/mixtral/benchmark.py
@@ -0,0 +1,259 @@
+# modified from llama benchmark
+import argparse
+import resource
+import time
+import warnings
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
+from data_utils import RandomDataset
+from model_utils import format_numel_str, get_model_numel
+from performance_evaluator import PerformanceEvaluator, get_profile_context
+from tqdm import tqdm
+from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
+
+import colossalai
+from colossalai.accelerator import get_accelerator
+from colossalai.booster import Booster
+from colossalai.booster.plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.shardformer import PipelineGradientCheckpointConfig
+
+warnings.filterwarnings("ignore")
+# ==============================
+# Constants
+# ==============================
+
+# We have lots of llamas for your choice!
+MODEL_CONFIGS = {
+    "100m": MixtralConfig(
+        max_position_embeddings=4096,
+        num_hidden_layers=4,
+        num_attention_heads=32,
+        intermediate_size=768,
+        hidden_size=768,
+        attn_implementation="flash_attention_2",
+    ),
+    "7b": MixtralConfig(
+        max_position_embeddings=4096,
+        num_hidden_layers=5,
+        attn_implementation="flash_attention_2",
+    ),
+    "14b": MixtralConfig(
+        max_position_embeddings=4096,
+        num_hidden_layers=10,
+        attn_implementation="flash_attention_2",
+    ),
+}
+
+
+def main():
+    # ==============================
+    # Parse Arguments
+    # ==============================
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--config", type=str, default="100m", help="Model configuration")
+    parser.add_argument(
+        "-p",
+        "--plugin",
+        choices=["3d"],
+        default="3d",
+        help="Choose which plugin to use",
+    )
+    parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch size")
+    parser.add_argument("-s", "--num_steps", type=int, default=5, help="Number of steps to run")
+    parser.add_argument("-i", "--ignore_steps", type=int, default=2, help="Number of steps to ignore")
+    parser.add_argument("-g", "--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
+    parser.add_argument("-l", "--max_length", type=int, default=4096, help="Max sequence length")
+    parser.add_argument(
+        "-w", "--warmup_ratio", type=float, default=0.8, help="warm up ratio of non-model data. Only for gemini-auto"
+    )
+    parser.add_argument("-m", "--memory_limit", type=int, help="Gemini memory limit in mb")
+    parser.add_argument("-x", "--xformers", action="store_true", help="Use xformers")
+    parser.add_argument("--shard_param_frac", type=float, default=1.0, help="Shard param fraction. Only for gemini")
+    parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
+    parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
+    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--ep", type=int, default=1, help="Expert parallel size")
+    parser.add_argument("--sp", type=int, default=1, help="Sequence parallel size")
+    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
+    parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
+    parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
+    parser.add_argument("--zero", type=int, default=1, help="Zero Stage when hybrid plugin is enabled")
+    parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+
+    parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved"])
+    parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval)
+    parser.add_argument("--profile", action="store_true", help="Profile the code")
+    parser.add_argument(
+        "--nsys",
+        action="store_true",
+        help="Use nsys for profiling. \
+        You should put something like this before colossalai launch: \
+        nsys profile -w true -t cuda,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --capture-range-end=stop --cudabacktrace=true -x true --python-backtrace=cuda -o prof_out",
+    )
+    parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation")
+    parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
+    parser.add_argument("--no_cache", action="store_true")
+    parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument(
+        "--sp_mode",
+        default="all_to_all",
+        choices=["all_to_all"],
+        help="Sequence parallelism mode",
+    )
+    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+    args = parser.parse_args()
+
+    colossalai.launch_from_torch()
+    coordinator = DistCoordinator()
+
+    # ckpt config for LLaMA3-70B on 64 H100 GPUs
+    hybrid_kwargs = (
+        {
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[19, 19, 19, 13],
+            ),
+            "num_layers_per_stage": [19, 20, 20, 21],
+            "pp_style": "interleaved",
+        }
+        if args.custom_ckpt
+        else {}
+    )
+
+    # ==============================
+    # Initialize Booster
+    # ==============================
+    if args.plugin == "3d":
+        plugin = MoeHybridParallelPlugin(
+            ep_size=args.ep,
+            tp_size=args.tp,
+            pp_size=args.pp,
+            pp_style=args.pp_style,
+            num_model_chunks=args.n_chunks,
+            zero_stage=args.zero,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            enable_sequence_parallelism=args.sp > 1,
+            enable_fused_normalization=torch.cuda.is_available(),
+            enable_flash_attention=args.xformers,
+            microbatch_size=args.mbs,
+            precision="bf16",
+            enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
+            use_fp8=args.use_fp8,
+            fp8_communication=args.use_fp8_comm,
+            **hybrid_kwargs,
+        )
+    else:
+        raise ValueError(f"Unknown plugin {args.plugin}")
+
+    booster = Booster(plugin=plugin)
+
+    # ==============================
+    # Initialize Dataset and Dataloader
+    # ==============================
+    dp_size = getattr(plugin, "dp_size", coordinator.world_size)
+
+    if args.config in MODEL_CONFIGS:
+        config = MODEL_CONFIGS[args.config]
+    else:
+        config = MixtralConfig.from_pretrained(args.config, trust_remote_code=True)
+    torch.cuda.manual_seed(42)
+
+    dataset = RandomDataset(
+        num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
+    )
+    dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, seed=42)
+
+    # ==============================
+    # Initialize Model and Optimizer
+    # ==============================
+    init_ctx = (
+        LazyInitContext(default_device=get_accelerator().get_current_device())
+        if isinstance(plugin, MoeHybridParallelPlugin)
+        else nullcontext()
+    )
+
+    with init_ctx:
+        model = MixtralForCausalLM(config=config).to(torch.bfloat16)
+
+    if args.grad_checkpoint:
+        model.gradient_checkpointing_enable()
+
+    model_numel = get_model_numel(model)
+    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
+    performance_evaluator = PerformanceEvaluator(
+        model_numel,
+        model.config.num_hidden_layers,
+        model.config.hidden_size,
+        model.config.vocab_size,
+        args.grad_checkpoint,
+        args.ignore_steps,
+        dp_world_size=dp_size,
+    )
+
+    optimizer = HybridAdam(model.parameters())
+    torch.set_default_dtype(torch.bfloat16)
+    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
+
+    torch.set_default_dtype(torch.float)
+    coordinator.print_on_master(
+        f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
+    )
+    coordinator.print_on_master(
+        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
+    )
+
+    with get_profile_context(
+        args.profile,
+        args.ignore_steps,
+        1,  # avoid creating massive log files
+        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+        nsys=args.nsys,
+    ) as prof:
+        if isinstance(plugin, MoeHybridParallelPlugin) and args.pp > 1:
+            data_iter = iter(dataloader)
+            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
+                performance_evaluator.on_step_start(step)
+                outputs = booster.execute_pipeline(
+                    data_iter,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
+                optimizer.step()
+                optimizer.zero_grad()
+
+                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                prof.step()
+        else:
+            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
+                performance_evaluator.on_step_start(step)
+                outputs = model(**batch)
+                loss = outputs[0]
+                del outputs  # free memory
+
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    print(f"Step {step} loss: {loss}")
+                booster.backward(loss, optimizer)
+                optimizer.step()
+                optimizer.zero_grad()
+
+                performance_evaluator.on_step_end(**batch)
+                prof.step()
+    performance_evaluator.on_fit_end()
+    coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/language/mixtral/data_utils.py b/examples/language/mixtral/data_utils.py
new file mode 120000
index 000000000..2da9822df
--- /dev/null
+++ b/examples/language/mixtral/data_utils.py
@@ -0,0 +1 @@
+../data_utils.py
\ No newline at end of file
diff --git a/examples/language/mixtral/model_utils.py b/examples/language/mixtral/model_utils.py
new file mode 120000
index 000000000..73c6818a8
--- /dev/null
+++ b/examples/language/mixtral/model_utils.py
@@ -0,0 +1 @@
+../model_utils.py
\ No newline at end of file
diff --git a/examples/language/mixtral/performance_evaluator.py b/examples/language/mixtral/performance_evaluator.py
new file mode 120000
index 000000000..f4736354b
--- /dev/null
+++ b/examples/language/mixtral/performance_evaluator.py
@@ -0,0 +1 @@
+../performance_evaluator.py
\ No newline at end of file
diff --git a/examples/language/mixtral/test_ci.sh b/examples/language/mixtral/test_ci.sh
new file mode 100755
index 000000000..e69de29bb
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
index 8c411a33f..dbcd28ab5 100644
--- a/tests/test_moe/moe_utils.py
+++ b/tests/test_moe/moe_utils.py
@@ -1,4 +1,12 @@
+import os
+import traceback
+from contextlib import contextmanager
+from time import sleep
+from typing import Callable, List, Optional
+
 import torch
+import torch.distributed as dist
+from torch.utils._pytree import tree_map
 
 
 def assert_loose_close(a, b, dtype: torch.dtype = torch.float32, name=""):
@@ -25,7 +33,66 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
     return torch.allclose(a, b, rtol=rtol, atol=atol)
 
 
-def check_model_equal(model1, model2):
+def check_model_equal(model1, model2, dtype):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        assert_loose_close(p1, p2, p1.dtype)
+        assert_loose_close(p1, p2, dtype, name=name)
+
+
+@contextmanager
+def distributed_debug_mode(num_stacks: int = 1, funcs_to_patch: Optional[List[Callable]] = None, enable=True):
+    if enable:
+        assert (
+            os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1"
+        ), f"Expect CUDA_LAUNCH_BLOCKING=1, got {os.environ.get('CUDA_LAUNCH_BLOCKING', '0')}"
+    if funcs_to_patch is None:
+        funcs_to_patch = [
+            dist.all_reduce,
+            dist.all_reduce_coalesced,
+            dist.all_gather,
+            dist.all_gather_coalesced,
+            dist.all_gather_into_tensor,
+            dist.all_to_all,
+            dist.all_to_all_single,
+            dist.reduce_scatter,
+        ]
+
+    original_funcs = {}
+    patched_funcs = {}
+
+    def make_patched(func):
+        def patched_func(*args, **kwargs):
+            stack = traceback.format_stack()
+
+            def format_node(node):
+                if isinstance(node, torch.Tensor):
+                    return f"{node.shape}"
+                elif isinstance(node, list):
+                    return f"[{', '.join([format_node(n) for n in node])}]"
+
+                return str(node)
+
+            args_str, kwargs_str = tree_map(format_node, (args, kwargs))
+            en = len(stack) - 1
+            st = max(0, en - num_stacks)
+            dist.barrier()
+            sleep(0.001 * dist.get_rank())
+            print(
+                f"[Rank {dist.get_rank()}-{func.__name__}-{dist.get_process_group_ranks(kwargs.get('group', dist.group.WORLD))}]: Called from {''.join(stack[st:en])}args={args_str} kwargs={kwargs_str}\n"
+            )
+            dist.barrier()
+            return func(*args, **kwargs)
+
+        return patched_func
+
+    if enable:
+        for func in funcs_to_patch:
+            original_funcs[func.__name__] = getattr(dist, func.__name__)
+            patched_funcs[func.__name__] = make_patched(func)
+            setattr(dist, func.__name__, patched_funcs[func.__name__])
+
+    try:
+        yield
+    finally:
+        for func_name, original_func in original_funcs.items():
+            setattr(dist, func_name, original_func)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 89f5d1c64..f3f109192 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -130,7 +130,7 @@ def check_moe_checkpoint(test_config):
         dist.barrier()
         if dist.get_rank() == 0:
             saved_model = model_cls.from_pretrained(model_dir).cuda().to(dtype)
-            check_model_equal(orig_model, saved_model)
+            check_model_equal(orig_model, saved_model, dtype=dtype)
             saved_model.save_pretrained(hf_model_dir)
         dist.barrier()
         # check load model
@@ -138,7 +138,7 @@ def check_moe_checkpoint(test_config):
         new_optimizer = Adam(new_model.parameters(), lr=1e-3)
         new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
         booster.load_model(new_model, hf_model_dir)
-        check_model_equal(model, new_model)
+        check_model_equal(model, new_model, dtype=dtype)
 
         # check save optimizer
         optimizer.step()
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 46da4522f..d782a2a09 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -12,43 +12,25 @@ from transformers import AutoConfig, AutoModel
 import colossalai
 from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
 
 NUM_BATCH = 8
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 2
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
 NUM_LAYERS = 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS = 4
+NUM_HEADS = 8
 TOP_K = 2
 
 
-CHECKED_CONFIG = [  # FOR_WORLD=4
-    (1, 4, 1, 1, 1),
-    (1, 1, 4, 1, 1),
-    (1, 1, 1, 4, 1),
-    (1, 1, 1, 1, 4),
-    (0, 1, 4, 1, 1),
-    (0, 1, 1, 4, 1),
-    (0, 1, 1, 1, 4),
-    (1, 2, 1, 1, 1),
-]
-
-
-@parameterize(
-    "config",
-    [
-        (1, 2, 2, 1, 1),
-        (1, 2, 1, 2, 1),
-        (1, 2, 1, 1, 2),
-    ],
-)
-def run_zero_with_original_model(config: Tuple[int, ...]):
+def run_deepseek_commom(config: Tuple[int, ...]):
+    Randomizer.reset_index()
     stage, ep_size, pp_size, tp_size, sp_size = config
     world_size = dist.get_world_size()
     rank = dist.get_rank()
-    dtype, precision = torch.float16, "fp16"
+    dtype, precision = torch.bfloat16, "bf16"
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
@@ -60,11 +42,11 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
-        enable_flash_attention=sp_size > 1,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
         find_unused_parameters=True,
+        enable_flash_attention=True,
     )
     dp_size = plugin.dp_size
 
@@ -171,7 +153,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()
 
     saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
-    check_model_equal(torch_model, saved_model)
+    check_model_equal(torch_model, saved_model, dtype=dtype)
     dist.barrier()
 
     if rank == world_size - 1:
@@ -180,17 +162,77 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     print(f"rank {dist.get_rank()} test passed")
 
 
-def run_dist(rank, world_size, port):
+@parameterize(
+    "config",
+    [
+        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
+        (0, 1, 4, 1, 1),
+        (0, 1, 1, 4, 1),
+        (0, 1, 2, 2, 1),
+        # zero 1
+        (1, 4, 1, 1, 1),
+        (1, 1, 4, 1, 1),
+        (1, 1, 1, 4, 1),
+        (1, 2, 1, 1, 2),
+        # zero 2
+        (2, 4, 1, 1, 1),
+        (2, 1, 4, 1, 1),
+        (2, 1, 1, 4, 1),
+        (2, 2, 1, 1, 2),
+    ],
+)
+def run_deepseek_test(config: Tuple[int, ...]):
+    run_deepseek_commom(config)
+
+
+@parameterize(
+    "config",
+    [
+        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
+        (0, 1, 2, 4, 1),
+        (0, 1, 4, 2, 1),
+        (0, 1, 1, 4, 1),
+        (0, 1, 4, 1, 1),
+        # zero 1:
+        (1, 2, 1, 1, 2),
+        (1, 2, 1, 4, 1),
+        (1, 1, 1, 2, 2),
+        (1, 2, 2, 2, 1),
+        # zero 2
+        (2, 2, 1, 1, 2),
+        (2, 2, 1, 4, 1),
+        (2, 1, 1, 2, 2),
+        (2, 2, 2, 2, 1),
+    ],
+)
+def run_deepseek_3d_test(config: Tuple[int, ...]):
+    run_deepseek_commom(config)
+
+
+def check_deepseek(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    run_zero_with_original_model()
+    run_deepseek_test()
+
+
+def check_deepseek_3d(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_deepseek_3d_test()
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_deepseek(world_size):
-    spawn(run_dist, world_size)
+    spawn(check_deepseek, world_size)
+
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_deepseek_3d(world_size):
+    spawn(check_deepseek_3d, world_size)
 
 
 if __name__ == "__main__":
-    test_deepseek(world_size=4)
+    test_deepseek(world_size=8)
+    test_deepseek_3d(world_size=8)
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index de09eedcb..940c66cf6 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -13,42 +13,25 @@ from transformers.models.mixtral.modeling_mixtral import MixtralModel
 import colossalai
 from colossalai.booster.booster import Booster
 from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import assert_loose_close, check_model_equal
 
 NUM_BATCH = 8
-NUM_TOK_PER_BATCH, NUM_EXPERTS = 4, 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 64, 4
 NUM_LAYERS = 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS = 4
-TOP_K = 1
-
-CHECKED_CONFIG = [  # FOR WORLD=4
-    (0, 1, 4, 1, 1),
-    (0, 1, 1, 4, 1),
-    (0, 1, 1, 1, 4),
-    (1, 4, 1, 1, 1),
-    (1, 1, 4, 1, 1),
-    (1, 1, 1, 4, 1),
-    (1, 1, 1, 1, 4),
-    (1, 2, 1, 1, 1),
-]
+NUM_HEADS = 8
+TOP_K = 2
 
 
-@parameterize(
-    "config",
-    [
-        (1, 2, 2, 1, 1),
-        (1, 2, 1, 2, 1),
-        (1, 2, 1, 1, 2),
-    ],
-)
-def run_zero_with_original_model(config: Tuple[int, ...]):
+def run_mixtral_commom(config: Tuple[int, ...]):
+    Randomizer.reset_index()
     stage, ep_size, pp_size, tp_size, sp_size = config
     world_size = dist.get_world_size()
     rank = dist.get_rank()
-    dtype, precision = torch.float16, "fp16"
+    dtype, precision = torch.bfloat16, "bf16"
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
@@ -165,7 +148,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     dist.barrier()
 
     saved_model = MixtralModel.from_pretrained(model_dir).cuda().to(dtype)
-    check_model_equal(torch_model, saved_model)
+    check_model_equal(torch_model, saved_model, dtype=dtype)
     dist.barrier()
 
     if rank == world_size - 1:
@@ -174,17 +157,78 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     print(f"rank {dist.get_rank()} test passed")
 
 
-def run_dist(rank, world_size, port):
+@parameterize(
+    "config",
+    [
+        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
+        (0, 1, 4, 1, 1),
+        (0, 1, 1, 4, 1),
+        (0, 1, 2, 2, 1),
+        # zero 1
+        (1, 4, 1, 1, 1),
+        (1, 1, 4, 1, 1),
+        (1, 1, 1, 4, 1),
+        (1, 2, 1, 1, 2),
+        # zero 2
+        (2, 4, 1, 1, 1),
+        (2, 1, 4, 1, 1),
+        (2, 1, 1, 4, 1),
+        (2, 2, 1, 1, 2),
+    ],
+)
+def run_mixtral_test(config: Tuple[int, ...]):
+    run_mixtral_commom(config)
+
+
+@parameterize(
+    "config",
+    [
+        # DDP: ep == 1 since ep * moe_dp == dp == moe_dp; sp == 1 since sp * dp == moe_dp == dp
+        (0, 1, 2, 4, 1),
+        (0, 1, 4, 2, 1),
+        (0, 1, 1, 4, 1),
+        (0, 1, 4, 1, 1),
+        # zero 1:
+        (1, 2, 1, 1, 2),
+        (1, 2, 1, 4, 1),
+        (1, 1, 1, 2, 2),
+        (1, 2, 2, 2, 1),
+        # zero 2
+        (2, 2, 1, 1, 2),
+        (2, 2, 1, 4, 1),
+        (2, 1, 1, 2, 2),
+        (2, 2, 2, 2, 1),
+    ],
+)
+def run_mixtral_3d_test(config: Tuple[int, ...]):
+    print(f"{config=}")
+    run_mixtral_commom(config)
+
+
+def check_mixtral(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
-    run_zero_with_original_model()
+    run_mixtral_test()
+
+
+def check_mixtral_3d(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_mixtral_3d_test()
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_mixtral(world_size):
-    spawn(run_dist, world_size)
+    spawn(check_mixtral, world_size)
+
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_mixtral_3d(world_size):
+    spawn(check_mixtral_3d, world_size)
 
 
 if __name__ == "__main__":
-    test_mixtral(world_size=4)
+    test_mixtral(world_size=8)
+    test_mixtral_3d(world_size=8)