From 34ca324b0d193623c89d8aea1aedb3c00ac2f654 Mon Sep 17 00:00:00 2001
From: BlueRum <70618399+ht-zhou@users.noreply.github.com>
Date: Wed, 22 Feb 2023 10:00:26 +0800
Subject: [PATCH] [chatgpt] Support saving ckpt in examples (#2846)

* [chatgpt]fix train_rm bug with lora
* [chatgpt]support colossalai strategy to train rm
* fix pre-commit
* fix pre-commit 2
* [chatgpt]fix rm eval typo
* fix rm eval
* fix pre commit
* add support of saving ckpt in examples
* fix single-gpu save
---
 applications/ChatGPT/examples/train_dummy.py   | 7 +++++++
 applications/ChatGPT/examples/train_prompts.py | 7 +++++++
 2 files changed, 14 insertions(+)

diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py
index f98b4792d..a27d77a50 100644
--- a/applications/ChatGPT/examples/train_dummy.py
+++ b/applications/ChatGPT/examples/train_dummy.py
@@ -97,6 +97,13 @@ def main(args):
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
 
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(actor_optim,
+                            'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
+                            only_rank0=False)
+
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py
index e79b2acf1..53aa150a0 100644
--- a/applications/ChatGPT/examples/train_prompts.py
+++ b/applications/ChatGPT/examples/train_prompts.py
@@ -2,6 +2,7 @@ import argparse
 from copy import deepcopy
 
 import pandas as pd
+import torch
 from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
 from chatgpt.trainer import PPOTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
@@ -95,6 +96,12 @@ def main(args):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(actor_optim,
+                            'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                            only_rank0=False)
 
 
 if __name__ == '__main__':
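
Usage note: a minimal sketch of restoring the actor weights saved by the
strategy.save_model() call added in train_dummy.py above. It assumes the naive
or DDP strategy wrote a plain PyTorch state_dict on rank 0 and that the actor
is rebuilt with the same configuration used during training; a ColossalAI
sharded checkpoint may need strategy-specific loading instead.

import torch

from chatgpt.nn import GPTActor

# Rebuild the actor as in train_dummy.py (assumes default GPT configuration),
# then load the rank-0 checkpoint written by strategy.save_model().
actor = GPTActor(lora_rank=0)
state_dict = torch.load('actor_checkpoint_dummy.pt', map_location='cpu')
# Depending on how the strategy unwraps the module before saving, the keys may
# target the actor itself or its inner model; adjust the target if keys mismatch.
actor.load_state_dict(state_dict)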