diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 28d83fa40..47f08cc0b 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -130,7 +130,7 @@ class BaseConsumer:
                     if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                         if self.rank == 0:
                             print(f"Start saving policy model at step {step + 1}.")
-                        save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
+                        save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
                         self.booster.save_model(self.policy_model, save_path, shard=True)
                         if self.rank == 0:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 9c0ec7922..b20e9dc38 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -10,6 +10,7 @@ if __name__ == "__main__":
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+    parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
 
     # Distributed training parameters
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
@@ -192,7 +193,7 @@ if __name__ == "__main__":
         num_producers=args.num_inferencer,
         num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
         num_consumer_procs=args.num_trainers,
-        num_episodes=1,
+        num_episodes=args.num_episodes,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,