mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[chatgpt]support opt & gpt for rm training (#2876)
@@ -3,12 +3,13 @@ import argparse
 import loralib as lora
 import torch
 from chatgpt.dataset import RewardDataset
-from chatgpt.nn import BLOOMRM
+from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
 from chatgpt.trainer import RewardModelTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset
 from torch.optim import Adam
-from transformers import BloomTokenizerFast
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

 from colossalai.nn.optimizer import HybridAdam

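The new imports bring in GPTRM and OPTRM alongside BLOOMRM. Reward models in this style typically wrap a pretrained transformer backbone with a small value head that maps hidden states to a scalar reward. The sketch below only illustrates that pattern; it is not the chatgpt.nn implementation, and SketchRewardModel is a hypothetical name.

import torch
import torch.nn as nn
from transformers import AutoModel

class SketchRewardModel(nn.Module):
    """Illustrative reward model: pretrained backbone plus a scalar value head."""

    def __init__(self, pretrained: str):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(pretrained)              # e.g. a GPT-2, OPT or BLOOM body
        self.value_head = nn.Linear(self.backbone.config.hidden_size, 1)   # hidden state -> scalar

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids, attention_mask=attention_mask).last_hidden_state
        values = self.value_head(hidden).squeeze(-1)   # (batch, seq_len)
        return values.mean(dim=-1)                     # one scalar reward per sequence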
@@ -27,11 +28,30 @@ def train(args):
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     # configure model
-    tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
-    tokenizer.pad_token = tokenizer.eos_token
     with strategy.model_init_context():
-        model = BLOOMRM(pretrained=args.pretrain).cuda()
-    max_len = 1024
+        if args.model == 'bloom':
+            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        elif args.model == 'opt':
+            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        elif args.model == 'gpt2':
+            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        else:
+            raise ValueError(f'Unsupported model "{args.model}"')
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'bloom':
+        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+    tokenizer.pad_token = tokenizer.eos_token
+
+    max_len = 512

     # configure optimizer
     if args.strategy.startswith('colossalai'):
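A padding token is required because examples are padded to a fixed length before batching, which is why the branches above set pad_token. The sketch below is a rough illustration of how a chosen/rejected pair from a dataset like Dahoas/rm-static could be tokenized with the selected tokenizer and max_len; the actual chatgpt.dataset.RewardDataset may work differently, and encode_pair is a hypothetical helper.

def encode_pair(tokenizer, prompt: str, chosen: str, rejected: str, max_len: int = 512):
    # Pad/truncate both candidate completions to the same fixed length so they
    # can be batched; this requires tokenizer.pad_token to be set.
    chosen_enc = tokenizer(prompt + chosen, max_length=max_len,
                           padding='max_length', truncation=True, return_tensors='pt')
    rejected_enc = tokenizer(prompt + rejected, max_length=max_len,
                             padding='max_length', truncation=True, return_tensors='pt')
    return chosen_enc, rejected_enc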
@@ -58,10 +78,10 @@ def train(args):

     trainer.fit(use_lora=args.lora_rank)

-    if args.lora_rank > 0:
-        torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
-    else:
-        torch.save(trainer.model, args.save_path)
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)


 if __name__ == '__main__':
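Checkpointing now goes through the strategy: the model is written once on rank 0, while each rank writes its own optimizer state. Below is a rough sketch of loading the rank-0 model checkpoint back for evaluation, under the assumption that save_model stores a plain state_dict; check the strategy you trained with before relying on this.

import torch
from chatgpt.nn import GPTRM

# Assumes 'rm_checkpoint.pt' holds a state_dict; adjust if your strategy
# saved the full module instead.
model = GPTRM(pretrained='gpt2')
state_dict = torch.load('rm_checkpoint.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()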
@@ -69,6 +89,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
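The new --model flag selects which of the three reward models to train, while --dataset defaults to Dahoas/rm-static, which provides chosen/rejected response pairs. Reward-model training on such pairs typically minimizes a pairwise ranking loss; the sketch below shows the standard form as an illustration, not necessarily the exact objective inside RewardModelTrainer.

import torch
import torch.nn.functional as F

def pairwise_ranking_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # Encourage r(chosen) > r(rejected): loss = -log(sigmoid(r_c - r_r)), averaged over the batch.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()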
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {

 set_n_least_used_CUDA_VISIBLE_DEVICES 2

-torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
+# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
+# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2