Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-27 03:21:47 +00:00)
fix checkpoint naming; add num_epoch parameter (#6277)
This commit is contained in:
parent
26d859f68e
commit
38008858e4
@@ -130,7 +130,7 @@ class BaseConsumer:
                 if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
-                    save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
+                    save_path = os.path.join(self.save_dir, f"modeling-episode-{episode}-step-{step + 1}")
                     self.booster.save_model(self.policy_model, save_path, shard=True)
                     if self.rank == 0:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
@@ -10,6 +10,7 @@ if __name__ == "__main__":
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+    parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
 
     # Distributed training parameters
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
@@ -192,7 +193,7 @@ if __name__ == "__main__":
         num_producers=args.num_inferencer,
         num_proc_per_producer=inference_model_config.get("tensor_parallel_size", 1),
         num_consumer_procs=args.num_trainers,
-        num_episodes=1,
+        num_episodes=args.num_episodes,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
Loading…
Reference in New Issue
Block a user