diff --git a/.gitignore b/.gitignore
index 16f764c1b..533450a7c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,3 +164,4 @@ coverage.xml
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
+applications/ColossalChat/model
diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 027acc2e7..c6ae7be2d 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -54,7 +54,6 @@ class BaseConsumer:
 
         self.model_config = model_config
         self.plugin_config = plugin_config
-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"
 
         self.device = get_current_device()
         self.lr_scheduler = None
@@ -95,7 +94,6 @@ class BaseConsumer:
                     i = 0
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
-
                         for r in range(self.num_producers):
                             print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                             self.buffer.extend(
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 4174f9651..a282439cb 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -39,6 +39,7 @@ class GRPOConsumer(BaseConsumer):
         use_wandb=True,
         generate_config=None,
         training_config={},
+        project_name=None,
     ):
         super().__init__(
             num_producers,
@@ -69,6 +70,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_count = 0
         self.generate_config = generate_config
         self.training_config = training_config
+        self.project_name = project_name
 
         # Reference model is initialized from policy model.
         self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -94,9 +96,7 @@ class GRPOConsumer(BaseConsumer):
 
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
-        if use_wandb and self.rank == 0:
-            name = f"{generate_config['backend']}_bs_{self.batch_size*self.world_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
-            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True, dir="./wandb", name=name)
+        self.use_wandb = use_wandb
 
         self.lr_scheduler = CosineAnnealingWarmupLR(
             optimizer=self.optimizer,
@@ -107,10 +107,19 @@ class GRPOConsumer(BaseConsumer):
 
     def setup(self):
         super().setup()
+        if self.use_wandb and (
+            (not self.plugin.pp_size > 1 and self.rank == 0)
+            or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage())
+        ):
+            # Initialize wandb.
+            name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
+            self.wandb_run = wandb.init(project=self.project_name, sync_tensorboard=True, dir="./wandb", name=name)
+
         self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
             self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
         )
         self.reference_model, *_ = self.booster.boost(self.reference_model)
+        self.plugin.logger.set_level("ERROR")
 
     def step(self, step_idx: int, **kwargs) -> Optional[float]:
         """
@@ -168,6 +177,7 @@ class GRPOConsumer(BaseConsumer):
                 ).repeat_interleave(self.num_generations, dim=0)
             )
             mean_kl, mean_loss = [], []
+
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), forward_batch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
@@ -186,112 +196,210 @@ class GRPOConsumer(BaseConsumer):
                 advantages_forward_micro_batch = advantages[
                     forward_micro_batch_start : forward_micro_batch_start + forward_batch_size
                 ]
-                policy_model_logits = self.policy_model(
-                    input_ids=input_ids_forward_micro_batch,
-                    attention_mask=attention_mask_forward_micro_batch,
-                ).logits
-                action_log_probs = calc_action_log_probs(
-                    policy_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
-                )
 
-                with torch.no_grad():
-                    reference_model_logits = self.reference_model(
+                if self.plugin.pp_size > 1:
+                    # Support training with PP.
+
+                    with torch.no_grad():
+                        reference_model_outputs = self.booster.execute_pipeline(
+                            iter(
+                                [
+                                    {
+                                        "input_ids": input_ids_forward_micro_batch,
+                                        "attention_mask": attention_mask_forward_micro_batch,
+                                    }
+                                ]
+                            ),
+                            self.reference_model,
+                            criterion=lambda outputs, inputs: torch.tensor(
+                                [0.0], device=action_mask.device
+                            ),  # dummy criterion
+                            optimizer=None,
+                            return_loss=False,
+                            return_outputs=True,
+                        )
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        reference_model_logits = reference_model_outputs["outputs"]["logits"]
+                        reference_action_log_probs = calc_action_log_probs(
+                            reference_model_logits / self.generate_config["temperature"],
+                            input_ids_forward_micro_batch,
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                    else:
+                        # Dummy reference logprobs for data iterator.
+                        reference_action_log_probs = None
+
+                    data_policy_forward = {
+                        "input_ids": input_ids_forward_micro_batch,
+                        "attention_mask": attention_mask_forward_micro_batch,
+                        "action_mask": action_mask_forward_micro_batch,
+                        "reference_action_log_probs": reference_action_log_probs,
+                        "advantages": advantages_forward_micro_batch,
+                        "loss_mask": loss_mask_forward_micro_batch,
+                        "source": self.rank,
+                    }
+
+                    kl = []
+
+                    def _criterion(outputs, inputs):
+                        action_logits = outputs.logits
+                        action_log_probs = calc_action_log_probs(
+                            action_logits / self.generate_config["temperature"],
+                            inputs["input_ids"],
+                            num_action,
+                            self.plugin.shard_config,
+                        )
+                        per_token_kl = (
+                            torch.exp(inputs["reference_action_log_probs"] - action_log_probs)
+                            - (inputs["reference_action_log_probs"] - action_log_probs)
+                            - 1
+                        )
+                        appox_kl = torch.sum(per_token_kl * inputs["action_mask"], dim=-1) / torch.sum(
+                            inputs["action_mask"], dim=-1
+                        )
+                        kl.append(appox_kl.mean())
+                        loss, skip_update, _ = self.policy_loss_fn(
+                            action_log_probs,
+                            action_log_probs,
+                            inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
+                            per_token_kl,
+                            inputs["action_mask"],
+                            loss_mask=inputs["loss_mask"],
+                        )
+                        return loss
+
+                    policy_model_outputs = self.booster.execute_pipeline(
+                        iter([data_policy_forward]),
+                        self.policy_model,
+                        criterion=_criterion,
+                        optimizer=self.optimizer,
+                        return_loss=True,
+                        return_outputs=True,
+                    )
+                    loss = policy_model_outputs["loss"]
+
+                    if self.booster.plugin.stage_manager.is_last_stage():
+                        if len(kl) > 0:
+                            kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin)
+                            mean_kl.append(kl)
+                        loss = all_reduce_mean(loss, self.plugin)
+                        mean_loss.append(loss.data)
+                else:
+
+                    policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
                         attention_mask=attention_mask_forward_micro_batch,
                     ).logits
-                reference_action_log_probs = calc_action_log_probs(
-                    reference_model_logits / self.generate_config["temperature"],
-                    input_ids_forward_micro_batch,
-                    num_action,
-                    self.plugin.shard_config,
-                )
+                    action_log_probs = calc_action_log_probs(
+                        policy_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
 
-                per_token_kl = (
-                    torch.exp(reference_action_log_probs - action_log_probs)
-                    - (reference_action_log_probs - action_log_probs)
-                    - 1
-                )
-                kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
-                    action_mask_forward_micro_batch, dim=-1
-                )
+                    with torch.no_grad():
+                        reference_model_logits = self.reference_model(
+                            input_ids=input_ids_forward_micro_batch,
+                            attention_mask=attention_mask_forward_micro_batch,
+                        ).logits
+                    reference_action_log_probs = calc_action_log_probs(
+                        reference_model_logits / self.generate_config["temperature"],
+                        input_ids_forward_micro_batch,
+                        num_action,
+                        self.plugin.shard_config,
+                    )
+                    per_token_kl = (
+                        torch.exp(reference_action_log_probs - action_log_probs)
+                        - (reference_action_log_probs - action_log_probs)
+                        - 1
+                    )
+                    kl = torch.sum(per_token_kl * action_mask_forward_micro_batch, dim=-1) / torch.sum(
+                        action_mask_forward_micro_batch, dim=-1
+                    )
 
-                loss, skip_update, _ = self.policy_loss_fn(
-                    action_log_probs,
-                    old_action_log_probs,
-                    advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
-                    per_token_kl,
-                    action_mask_forward_micro_batch,
-                    loss_mask=loss_mask_forward_micro_batch,
-                )
+                    loss, skip_update, _ = self.policy_loss_fn(
+                        action_log_probs,
+                        old_action_log_probs,
+                        advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
+                        per_token_kl,
+                        action_mask_forward_micro_batch,
+                        loss_mask=loss_mask_forward_micro_batch,
+                    )
 
-                if not skip_update:
-                    self.booster.backward(loss, self.optimizer)
-                loss = all_reduce_mean(loss, self.plugin)
-                kl = all_reduce_mean(kl.mean(), self.plugin)
-                # Calculate accumulate value.
-                mean_kl.append(kl.data)
-                mean_loss.append(loss.data)
-
-            reward = all_reduce_mean(reward.mean(), self.plugin)
-            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
-            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
-            advantages = all_reduce_mean(advantages.mean(), self.plugin)
-            response_length = all_reduce_mean(response_length.mean(), self.plugin)
-            self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
-            self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
-            self.accum_reward.add_(reward.data)
-            self.accum_format_reward.add_(format_reward.data)
-            self.accum_acc_reward.add_(acc_reward.data)
-            self.accum_advantages.add_(advantages.data)
-            self.accum_response_length.add_(response_length.data)
-            self.accum_count += 1
+                    if not skip_update:
+                        self.booster.backward(loss, self.optimizer)
+                    loss = all_reduce_mean(loss, self.plugin)
+                    kl = all_reduce_mean(kl.mean(), self.plugin)
+                    # Calculate accumulate value.
+                    mean_kl.append(kl.data)
+                    mean_loss.append(loss.data)
+            if not self.plugin.pp_size > 1 or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            ):
+                reward = all_reduce_mean(reward.mean(), self.plugin)
+                format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+                acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+                advantages = all_reduce_mean(advantages.mean(), self.plugin)
+                response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
+                self.accum_reward.add_(reward.data)
+                self.accum_format_reward.add_(format_reward.data)
+                self.accum_acc_reward.add_(acc_reward.data)
+                self.accum_advantages.add_(advantages.data)
+                self.accum_response_length.add_(response_length.data)
+                self.accum_count += 1
         if need_update:
             self.optimizer.step()
             self.optimizer.zero_grad()
-            loss_scalar = self.accum_loss.item()
-            if self.rank == 0:
-                print(
-                    "Loss:",
-                    self.accum_loss.item() / self.accum_count,
-                    "\nReward:",
-                    self.accum_reward.item() / self.accum_count,
-                    "\nFormat Reward:",
-                    self.accum_format_reward.item() / self.accum_count,
-                    "\nAcc Reward:",
-                    self.accum_acc_reward.item() / self.accum_count,
-                    "\nKL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "\nAdvantages:",
-                    self.accum_advantages.item() / self.accum_count,
-                    "\nResponse Length:",
-                    self.accum_response_length.item() / self.accum_count,
-                )
-                self.wandb_run.log(
-                    {
-                        "metrics/reward": self.accum_reward.item() / self.accum_count,
-                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
-                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
-                        "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
-                        "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
-                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
-                    }
-                )
-            self.accum_loss.zero_()
-            self.accum_reward.zero_()
-            self.accum_acc_reward.zero_()
-            self.accum_format_reward.zero_()
-            self.accum_kl.zero_()
-            self.accum_advantages.zero_()
-            self.accum_response_length.zero_()
+            if not self.plugin.pp_size > 1 or (
+                self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+            ):
+                loss_scalar = self.accum_loss.item()
+                if (not self.plugin.pp_size > 1 and self.rank == 0) or (
+                    self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage()
+                ):
+                    print(
+                        "Loss:",
+                        self.accum_loss.item() / self.accum_count,
+                        "\nReward:",
+                        self.accum_reward.item() / self.accum_count,
+                        "\nFormat Reward:",
+                        self.accum_format_reward.item() / self.accum_count,
+                        "\nAcc Reward:",
+                        self.accum_acc_reward.item() / self.accum_count,
+                        "\nKL:",
+                        self.accum_kl.item() / self.accum_count,
+                        "\nAdvantages:",
+                        self.accum_advantages.item() / self.accum_count,
+                        "\nResponse Length:",
+                        self.accum_response_length.item() / self.accum_count,
+                    )
+                    self.wandb_run.log(
+                        {
+                            "metrics/reward": self.accum_reward.item() / self.accum_count,
+                            "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                            "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                            "metrics/response_length": self.accum_response_length.item() / self.accum_count,
+                            "train/loss": self.accum_loss.item() / self.accum_count,
+                            "train/kl": self.accum_kl.item() / self.accum_count,
+                            "train/advantages": self.accum_advantages.item() / self.accum_count,
+                            "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                            "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
+                        }
+                    )
+                self.accum_loss.zero_()
+                self.accum_reward.zero_()
+                self.accum_acc_reward.zero_()
+                self.accum_format_reward.zero_()
+                self.accum_kl.zero_()
+                self.accum_advantages.zero_()
+                self.accum_response_length.zero_()
 
-            self.accum_count = 0
-            return loss_scalar
+                self.accum_count = 0
+                return loss_scalar
 
     def state_dict(self):
         self.policy_model._force_wait_all_gather()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index ba5d3a9d4..699d90a8c 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):
 
     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 4a4a4c340..f87f12ed2 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -10,13 +10,44 @@ if __name__ == "__main__":
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-    parser.add_argument("-g", "--num-generations", type=int, default=8)
-    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
-    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
-    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-    parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
-    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
-    parser.add_argument("-b", "--backend", type=str, default="transformers")
+    parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+    parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+    parser.add_argument(
+        "-ibs",
+        "--inference-batch-size",
+        type=int,
+        default=64,
+        help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
+    )
+    parser.add_argument(
+        "-imbs",
+        "--inference-microbatch-size",
+        type=int,
+        default=8,
+        help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
+    )
+    parser.add_argument(
+        "-tbs",
+        "--train-batch-size",
+        type=int,
+        default=32,
+        help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
+    )
+    parser.add_argument(
+        "-tMbs",
+        "--train-minibatch-size",
+        type=int,
+        default=1,
+        help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
+    )
+    parser.add_argument(
+        "-tmbs",
+        "--train-microbatch-size",
+        type=int,
+        default=2,
+        help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
+    )
+    parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
     args = parser.parse_args()
 
@@ -29,11 +60,7 @@ if __name__ == "__main__":
     ray.init(address="local", namespace="ray-example")
 
     inference_model_config = dict(path=args.model)
-    train_model_config = dict(
-        path=args.model,
-        # use_flash_attention_2=True,
-        # use_cache=False
-    )
+    train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
     generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)
 
     if args.backend == "transformers":
@@ -91,9 +118,17 @@ if __name__ == "__main__":
         generate_config=generate_config,
         num_generations=args.num_generations,
         train_model_config=train_model_config,
-        plugin_config={},
+        # plugin_config={}, # for zero
+        plugin_config={
+            "pp_size": 2,
+            "tp_size": 1,
+            "microbatch_size": args.train_microbatch_size // 2,
+            "zero_stage": 0,
+            "max_norm": 1.0,
+        },  # for pp
         inference_backend=args.backend,
         master_addr="localhost",
-        master_port=29505,
+        master_port=29506,
         core_algo=args.algo,
+        project_name=args.project,
     )
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 1684fd702..74349091b 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1411,8 +1411,10 @@ class HybridParallelPlugin(PipelinePluginBase):
             )
 
         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs
 
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 71e3557fe..27571309e 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -284,6 +284,7 @@ class Qwen2PipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
+        **kwargs,
     ):
         r"""
         Args: