diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py index f98b4792d..a27d77a50 100644 --- a/applications/ChatGPT/examples/train_dummy.py +++ b/applications/ChatGPT/examples/train_dummy.py @@ -97,6 +97,13 @@ def main(args): max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + # after fitting, save the model checkpoint on rank0 only + strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True) + # save the optimizer checkpoint on all ranks, one file per current CUDA device index + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + if __name__ == '__main__': parser = argparse.ArgumentParser() diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py index e79b2acf1..53aa150a0 100644 --- a/applications/ChatGPT/examples/train_prompts.py +++ b/applications/ChatGPT/examples/train_prompts.py @@ -2,6 +2,7 @@ import argparse from copy import deepcopy import pandas as pd +import torch from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel from chatgpt.trainer import PPOTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy @@ -95,6 +96,12 @@ def main(args): num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) + # after fitting, save the model checkpoint on rank0 only + strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True) + # save the optimizer checkpoint on all ranks, one file per current CUDA device index + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) if __name__ == '__main__':