mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 16:28:10 +00:00
[chatgpt] Support saving ckpt in examples (#2846)
* [chatgpt]fix train_rm bug with lora
* [chatgpt]support colossalai strategy to train rm
* fix pre-commit
* fix pre-commit 2
* [chatgpt]fix rm eval typo
* fix rm eval
* fix pre commit
* add support of saving ckpt in examples
* fix single-gpu save
This commit is contained in:
parent
597914317b
commit
34ca324b0d
@ -97,6 +97,13 @@ def main(args):
|
|||||||
max_timesteps=args.max_timesteps,
|
max_timesteps=args.max_timesteps,
|
||||||
update_timesteps=args.update_timesteps)
|
update_timesteps=args.update_timesteps)
|
||||||
|
|
||||||
|
# after fitting, save the model checkpoint on rank 0 only
|
||||||
|
strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
|
||||||
|
# save optimizer checkpoint on all ranks
|
||||||
|
strategy.save_optimizer(actor_optim,
|
||||||
|
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
|
||||||
|
only_rank0=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -2,6 +2,7 @@ import argparse
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import torch
|
||||||
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
|
from chatgpt.nn import BLOOMActor, BLOOMCritic, GPTActor, GPTCritic, OPTActor, OPTCritic, RewardModel
|
||||||
from chatgpt.trainer import PPOTrainer
|
from chatgpt.trainer import PPOTrainer
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||||
@ -95,6 +96,12 @@ def main(args):
|
|||||||
num_episodes=args.num_episodes,
|
num_episodes=args.num_episodes,
|
||||||
max_timesteps=args.max_timesteps,
|
max_timesteps=args.max_timesteps,
|
||||||
update_timesteps=args.update_timesteps)
|
update_timesteps=args.update_timesteps)
|
||||||
|
# after fitting, save the model checkpoint on rank 0 only
|
||||||
|
strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
|
||||||
|
# save optimizer checkpoint on all ranks
|
||||||
|
strategy.save_optimizer(actor_optim,
|
||||||
|
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||||
|
only_rank0=False)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user