From 9df4a8047a14aaabade148d69b583f5c311cd09a Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Sat, 26 Apr 2025 13:15:12 +0800 Subject: [PATCH] fix checkpoint naming; add num_episodes parameter --- applications/ColossalChat/coati/distributed/consumer.py | 2 +- applications/ColossalChat/rl_example.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py index 28d83fa40..47f08cc0b 100644 --- a/applications/ColossalChat/coati/distributed/consumer.py +++ b/applications/ColossalChat/coati/distributed/consumer.py @@ -130,7 +130,7 @@ class BaseConsumer: if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode: if self.rank == 0: print(f"Start saving policy model at step {step + 1}.") - save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}") + save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}") self.booster.save_model(self.policy_model, save_path, shard=True) if self.rank == 0: print(f"Saved model checkpoint at step {step + 1} in folder {save_path}") diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py index 9c0ec7922..b20e9dc38 100644 --- a/applications/ColossalChat/rl_example.py +++ b/applications/ColossalChat/rl_example.py @@ -10,6 +10,7 @@ if __name__ == "__main__": parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B") parser.add_argument("-d", "--dataset", type=str, default="data.jsonl") parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.") + parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.") # Distributed training parameters parser.add_argument("-t", "--num-trainers", type=int, default=2) @@ -192,7 +193,7 @@ if __name__ == "__main__": num_producers=args.num_inferencer, 
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1), num_consumer_procs=args.num_trainers, - num_episodes=1, + num_episodes=args.num_episodes, inference_batch_size=args.inference_batch_size, inference_microbatch_size=args.inference_microbatch_size, train_batch_size=args.train_batch_size,