diff --git a/colossalai/shardformer/modeling/deepseek.py b/colossalai/shardformer/modeling/deepseek.py
index 6e79ce144..33fac9b93 100644
--- a/colossalai/shardformer/modeling/deepseek.py
+++ b/colossalai/shardformer/modeling/deepseek.py
@@ -1,21 +1,27 @@
-from typing import List, Optional, Union
+from typing import List, Optional
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
-
-# from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import is_flash_attn_2_available, logging
 
 from colossalai.lazy import LazyInitContext
-from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
+from colossalai.moe._operation import (
+    DPGradScalerIn,
+    DPGradScalerOut,
+    EPGradScalerIn,
+    EPGradScalerOut,
+    all_to_all_uneven,
+)
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.linear import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.shard import ShardConfig
 from colossalai.shardformer.shard.utils import set_tensors_to_none
+from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group
 
 
 # copied from modeling_deepseek.py
@@ -42,30 +48,60 @@ class AddAuxiliaryLoss(torch.autograd.Function):
 
 class EPDeepseekMoE(nn.Module):
     def __init__(self):
-        super(EPDeepseekMoE, self).__init__()
+        raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")
 
-    def setup_ep(self, ep_group: ProcessGroup):
-        ep_group = ep_group
-        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
-        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
+    def setup_process_groups(
+        self, tp_group: ProcessGroup, moe_dp_group: ProcessGroup, ep_group: ProcessGroup, moe_tp_group: ProcessGroup
+    ):
+        assert tp_group is not None
+        assert moe_dp_group is not None
+        assert ep_group is not None
+        assert moe_tp_group is not None
+
+        self.ep_size = dist.get_world_size(ep_group)
+        self.ep_rank = dist.get_rank(ep_group)
         self.num_experts = self.config.n_routed_experts
         assert self.num_experts % self.ep_size == 0
+
         self.ep_group = ep_group
         self.num_experts_per_ep = self.num_experts // self.ep_size
         self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
         held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
+
         set_tensors_to_none(self.experts, exclude=set(held_experts))
         for p in self.experts.parameters():
-            p.ep_group = ep_group
+            set_moe_tensor_ep_group(p, ep_group)
+
+        # setup moe_dp group
+        self.moe_dp_group = moe_dp_group
+        self.moe_dp_size = moe_dp_group.size()
+
+        # setup global tp group
+        self.tp_group = tp_group
+
+        # setup moe tp group
+        self.moe_tp_group = moe_tp_group
+        if self.moe_tp_group.size() > 1:
+            for expert in held_experts:
+                expert.gate_proj = Linear1D_Col.from_native_module(expert.gate_proj, self.moe_tp_group)
+                expert.up_proj = Linear1D_Col.from_native_module(expert.up_proj, self.moe_tp_group)
+                expert.down_proj = Linear1D_Row.from_native_module(expert.down_proj, self.moe_tp_group)
 
     @staticmethod
-    def from_native_module(module: Union["DeepseekMoE", "DeepseekMLP"], *args, **kwargs) -> "EPDeepseekMoE":
+    def from_native_module(
+        module,
+        tp_group: ProcessGroup,
+        moe_dp_group: ProcessGroup,
+        ep_group: ProcessGroup,
+        moe_tp_group: ProcessGroup,
+        *args,
+        **kwargs,
+    ) -> "EPDeepseekMoE":
         LazyInitContext.materialize(module)
         if module.__class__.__name__ == "DeepseekMLP":
             return module
         module.__class__ = EPDeepseekMoE
-        assert "ep_group" in kwargs, "You should pass ep_group in SubModuleReplacementDescription via shard_config!!"
-        module.setup_ep(kwargs["ep_group"])
+        module.setup_process_groups(tp_group, moe_dp_group, ep_group, moe_tp_group)
         return module
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -91,15 +127,24 @@ class EPDeepseekMoE(nn.Module):
         # [n0, n1, n2, n3] [m0, m1, m2, m3] -> [n0, n1, m0, m1] [n2, n3, m2, m3]
         dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)
 
+        with torch.no_grad():
+            activate_experts = output_split_sizes[: self.num_experts_per_ep].clone()
+            for i in range(1, self.ep_size):
+                activate_experts += output_split_sizes[i * self.num_experts_per_ep : (i + 1) * self.num_experts_per_ep]
+            activate_experts = (activate_experts > 0).float()
+        dist.all_reduce(activate_experts, group=self.moe_dp_group)
+
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
-        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
+        output_states = EPGradScalerIn.apply(output_states, self.ep_size)
 
         if output_states.size(0) > 0:
             if self.num_experts_per_ep == 1:
                 expert = self.experts[self.expert_start_idx]
+                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
                 output_states = expert(output_states)
+                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
             else:
                 output_states_splits = output_states.split(output_split_sizes.tolist())
                 output_states_list = []
@@ -107,10 +152,16 @@ class EPDeepseekMoE(nn.Module):
                     if split_states.size(0) == 0:  # no token routed to this experts
                         continue
                     expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+                    split_states = DPGradScalerIn.apply(
+                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+                    )
                     split_states = expert(split_states)
+                    split_states = DPGradScalerOut.apply(
+                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+                    )
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
-        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
+        output_states = EPGradScalerOut.apply(output_states, self.ep_size)
         dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
         recover_token_idx = torch.empty_like(flat_topk_token_idx)
         recover_token_idx[flat_topk_token_idx] = torch.arange(
diff --git a/colossalai/shardformer/modeling/mixtral.py b/colossalai/shardformer/modeling/mixtral.py
index 86ef6c959..cfa7da6c0 100644
--- a/colossalai/shardformer/modeling/mixtral.py
+++ b/colossalai/shardformer/modeling/mixtral.py
@@ -116,8 +116,6 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
         input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
         output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
 
-        # TODO drop tokens to reduce tp group redundant communication
-
         output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
         # compute expert output
         output_states = EPGradScalerIn.apply(output_states, self.ep_size)
@@ -125,24 +123,24 @@ class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
             if self.num_experts_per_ep == 1:
                 # no need to split
                 expert = self.experts[self.expert_start_idx]
-                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0].item())
+                output_states = DPGradScalerIn.apply(output_states, self.moe_dp_size, activate_experts[0])
                 output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                 output_states = expert.w2(output_states)
-                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0].item())
+                output_states = DPGradScalerOut.apply(output_states, self.moe_dp_size, activate_experts[0])
             else:
                 output_states_splits = output_states.split(output_split_sizes.tolist())
                 output_states_list = []
                 for i, split_states in enumerate(output_states_splits):
                     if split_states.size(0) == 0:
                         continue
-                    split_states = DPGradScalerIn.apply(
-                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
-                    )
                     expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
+                    split_states = DPGradScalerIn.apply(
+                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
+                    )
                     split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                     split_states = expert.w2(split_states)
                     split_states = DPGradScalerOut.apply(
-                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep].item()
+                        split_states, self.moe_dp_size, activate_experts[i % self.num_experts_per_ep]
                     )
                     output_states_list.append(split_states)
                 output_states = torch.cat(output_states_list)
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index 1e0af031a..f2533da4b 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -161,7 +161,7 @@ _POLICY_LIST = {
         file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
     ),
     # Deepseek
-    "transformers_modules.modeling_deepseek.DeepSeekModel": PolicyLocation(
+    "transformers_modules.modeling_deepseek.DeepseekModel": PolicyLocation(
         file_name="deepseek", class_name="DeepseekModelPolicy"
     ),
     "transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index 8ebda357b..5a67d653d 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -7,6 +7,7 @@ from torch import Tensor
 from torch.nn import Module
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.linear import Linear1D_Row
 from colossalai.shardformer.modeling.deepseek import DeepseekPipelineForwards, EPDeepseekMoE
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
@@ -39,16 +40,55 @@ class DeepseekPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Deepseek model now.")
+            # tensor parallelism for non-moe params
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+                // self.shard_config.tensor_parallel_size,
+            }
 
-        if getattr(self.shard_config, "ep_group", None) is not None:
+            policy["DeepseekDecoderLayer"] = ModulePolicyDescription(
+                attribute_replacement=decoder_attribute_replacement,
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.q_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.k_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.v_proj",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attn.o_proj",
+                        target_module=Linear1D_Row,
+                    ),
+                ],
+            )
+
+        if self.shard_config.ep_group:
             # expert parallel
             self.append_or_create_submodule_replacement(
                 description=[
                     SubModuleReplacementDescription(
                         suffix="mlp",
                         target_module=EPDeepseekMoE,
-                        kwargs={"ep_group": self.shard_config.ep_group},
+                        kwargs={
+                            "ep_group": self.shard_config.ep_group,
+                            "tp_group": self.shard_config.tensor_parallel_process_group,
+                            "moe_dp_group": self.shard_config.moe_dp_group,
+                            "moe_tp_group": self.shard_config.moe_tp_group,
+                        },
                     )
                 ],
                 policy=policy,
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 4b77a167f..8905b5696 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -8,6 +8,7 @@ from torch.nn import Module
 from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralForCausalLM, MixtralModel
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
 from colossalai.shardformer.layer.linear import Linear1D_Row
 from colossalai.shardformer.modeling.mixtral import EPMixtralSparseMoeBlock, MixtralPipelineForwards
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -42,6 +43,13 @@ class MixtralPolicy(Policy):
                 "Mixtral dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
             )
 
+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_tensor_parallelism:
             # tensor parallelism for non-moe params
             assert (
@@ -76,13 +84,22 @@ class MixtralPolicy(Policy):
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
                     ),
-                    SubModuleReplacementDescription(
+                    SubModuleReplacementDescription(  # or replicate?
                         suffix="block_sparse_moe.gate", target_module=Linear1D_Col, kwargs={"gather_output": True}
                     ),
                 ],
             )
 
-            # TODO shard vocab embedding
+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=MixtralModel,
+            )
 
         if self.shard_config.ep_group:
             # expert parallel
diff --git a/tests/test_moe/modelling/test_deepseek.py b/tests/test_moe/modelling/test_deepseek.py
new file mode 100644
index 000000000..42daea512
--- /dev/null
+++ b/tests/test_moe/modelling/test_deepseek.py
@@ -0,0 +1,133 @@
+import os
+import shutil
+from copy import deepcopy
+from typing import Tuple
+
+import pytest
+import torch
+import torch.distributed as dist
+from transformers import AutoConfig, AutoModel
+
+import colossalai
+from colossalai.booster.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import loose_close
+from tests.test_moe.test_moe_checkpoint import check_model_equal
+
+NUM_BATCH = 4
+NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
+HIDDEN_SIZE_PER_HEAD = 4
+NUM_HEADS = 4
+TOP_K = 1
+
+
+@parameterize("config", [(1, 1, 1)])
+def run_zero_with_original_model(config: Tuple[int, ...]):
+    stage, ep_size, tp_size = config
+    dtype = torch.float16
+
+    rank = torch.distributed.get_rank()
+    torch.cuda.set_device(dist.get_rank())
+
+    plugin = MoeHybridParallelPlugin(
+        pp_size=1,
+        tp_size=tp_size,
+        moe_tp_size=tp_size,
+        ep_size=ep_size,
+        zero_stage=stage,
+        overlap_communication=False,
+        initial_scale=1,
+        precision="fp32",
+    )
+    booster = Booster(plugin=plugin)
+
+    seed_all(10086)
+
+    config = AutoConfig.from_pretrained("deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
+    config.hidden_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS
+    config.intermediate_size = HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2
+    config.num_hidden_layers = 2
+    config.num_attention_heads = NUM_HEADS
+    config.num_key_value_heads = NUM_HEADS
+    config.n_routed_experts = NUM_EXPERTS
+    config.num_experts_per_tok = TOP_K
+    torch_model = AutoModel.from_config(config, trust_remote_code=True).cuda().to(dtype)
+
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
+
+    zero_model = deepcopy(torch_model).to(dtype)
+    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+
+    zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+    # create different input
+    seed_all(1453 + rank)
+
+    torch_model.train()
+    zero_model.train()
+    for _ in range(2):
+        input_data = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
+        dist.all_reduce(input_data, group=plugin.tp_group)  # tp requires duplicate input
+
+        zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
+        zero_optimizer.backward(zero_output)
+        zero_optimizer.step()
+        zero_optimizer.zero_grad()
+        dist.all_reduce(zero_output)
+
+        all_inputs = [torch.empty_like(input_data) for _ in range(dist.get_world_size())]
+        dist.all_gather(all_inputs, input_data)
+
+        torch_output_sum = 0
+        for input_data_ in all_inputs:
+            torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
+            torch_output.backward()
+            torch_output_sum += torch_output.detach()
+        # avg dp grads
+        for p in torch_model.parameters():
+            if p.grad is not None:
+                p.grad /= dist.get_world_size()
+        torch_optimizer.step()
+        torch_optimizer.zero_grad()
+
+        loose_close(zero_output, torch_output_sum, dtype=dtype)
+
+    # use checkpoint to load sharded zero model
+    model_dir = "./test_deepseek"
+    if dist.get_rank() == 0:
+        os.makedirs(model_dir, exist_ok=True)
+
+    dist.barrier()
+
+    booster.save_model(zero_model, model_dir, shard=True)
+
+    dist.barrier()
+
+    saved_model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).cuda()
+    check_model_equal(torch_model, saved_model)
+
+    dist.barrier()
+    if dist.get_rank() == 0:
+        shutil.rmtree(model_dir)
+
+    print(f"{dist.get_rank()} test passed")
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_zero_with_original_model()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_mistral(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+    test_mistral(world_size=4)
diff --git a/tests/test_moe/modelling/test_mixtral.py b/tests/test_moe/modelling/test_mixtral.py
index 8309bfb22..6e6f0b2b5 100644
--- a/tests/test_moe/modelling/test_mixtral.py
+++ b/tests/test_moe/modelling/test_mixtral.py
@@ -24,16 +24,6 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-def split_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
 @parameterize("config", [(1, 1, 4), (1, 2, 2), (1, 4, 1)])
 def run_zero_with_original_model(config: Tuple[int, ...]):
     stage, ep_size, tp_size = config
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index 6f3c5b299..4bcf701de 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -16,6 +16,7 @@ from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParall
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.testing import parameterize, spawn
 from colossalai.testing.utils import spawn
+from tests.test_moe.moe_utils import loose_close
 
 tokens, n_experts = 7, 4
 hidden_size = 8
@@ -25,7 +26,7 @@ top_k = 2
 def check_model_equal(model1, model2):
     assert set(model1.state_dict().keys()) == set(model2.state_dict().keys())
     for i, ((name, p1), p2) in enumerate(zip(model1.named_parameters(), model2.parameters())):
-        if not torch.equal(p1.half(), p2.half()):
+        if loose_close(p1, p2, p1.dtype):
             print(f"Model parameter {name} is not equal. is_moe_tensor: {is_moe_tensor(p1)}")
             raise AssertionError(f"Model parameter {name} is not equal")
 
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
index e944a8c0a..29881c9ab 100644
--- a/tests/test_moe/test_moe_ep_tp.py
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -21,16 +21,6 @@ NUM_HEADS = 4
 TOP_K = 2
 
 
-def split_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
-
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
 def run_zero_with_original_model(stage: int, ep_size: int):
diff --git a/tests/test_moe/test_moe_ep_zero.py b/tests/test_moe/test_moe_ep_zero.py
index c5adaad06..40e3bacb3 100644
--- a/tests/test_moe/test_moe_ep_zero.py
+++ b/tests/test_moe/test_moe_ep_zero.py
@@ -14,21 +14,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from tests.test_moe.moe_utils import loose_close
 
-NUM_BATCH=4
+NUM_BATCH = 4
 NUM_TOK_PER_BATCH, NUM_EXPERTS = 7, 4
 HIDDEN_SIZE_PER_HEAD = 4
-NUM_HEADS=2
+NUM_HEADS = 2
 TOP_K = 1
 
-def split_grad(grad, world_size):
-    with torch.no_grad():
-        grad = grad.clone().detach().flatten()
-        padding_size = (world_size - grad.numel() % world_size) % world_size
-        if padding_size > 0:
-            grad = torch.nn.functional.pad(grad, [0, padding_size])
-        splited_grad = grad.split(grad.numel() // world_size)
-    return splited_grad
-
 
 @parameterize("stage", [1])
 @parameterize("ep_size", [1, 2, 4])
@@ -39,12 +30,7 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     torch.cuda.set_device(dist.get_rank())
 
     plugin = MoeHybridParallelPlugin(
-        pp_size=1,
-        tp_size=1,
-        ep_size=ep_size,
-        zero_stage=stage,
-        overlap_communication=False,
-        initial_scale=1
+        pp_size=1, tp_size=1, ep_size=ep_size, zero_stage=stage, overlap_communication=False, initial_scale=1
     )
     booster = Booster(plugin=plugin)
 
@@ -81,7 +67,9 @@ def run_zero_with_original_model(stage: int, ep_size: int):
     zero_model.train()
     for _ in range(2):
         # zero-dp forward
-        input_data = torch.rand(NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True).cuda()
+        input_data = torch.rand(
+            NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
+        ).cuda()
         zero_output = zero_model(inputs_embeds=input_data.to(dtype)).last_hidden_state.mean()
         # zero-dp backward
         zero_optimizer.backward(zero_output)