diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 79beb2a2d..b7b865b26 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -119,7 +119,7 @@ class BaseConsumer:
                 assert len(self.buffer) == 0
                 if self.lr_scheduler is not None:
                     self.lr_scheduler.step()
-                if (step + 1) % self.save_interval == 0:
+                if (step + 1) % self.save_interval == 0 or (step + 1) == self.num_update_per_episode:
                     if self.rank == 0:
                         print(f"Start saving policy model at step {step + 1}.")
                     save_path = os.path.join(self.save_dir, f"modeling-step-{step + 1}")
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 317446695..f42a660b7 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -107,7 +107,7 @@ if __name__ == "__main__":
         num_producers=args.num_inferencer,
         num_proc_per_producer=1,
         num_consumer_procs=args.num_trainers,
-        num_episodes=10,
+        num_episodes=1,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
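
Not part of the patch: a minimal Python sketch of the revised save condition in BaseConsumer, assuming step counts from 0 and num_update_per_episode is the number of optimizer updates per episode (names taken from the diff above). It illustrates that the final update of an episode is now always checkpointed, even when it does not fall on a save_interval boundary.

# Hypothetical helper for illustration only; not part of consumer.py.
def should_save(step: int, save_interval: int, num_update_per_episode: int) -> bool:
    # Old rule: checkpoint only every `save_interval` updates.
    # New rule: also checkpoint at the last update of the episode.
    return (step + 1) % save_interval == 0 or (step + 1) == num_update_per_episode

# With save_interval=4 and 10 updates per episode, checkpoints land at
# steps 4, 8 and 10; step 10 is the new end-of-episode save.
print([s + 1 for s in range(10) if should_save(s, save_interval=4, num_update_per_episode=10)])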