ColossalAI/applications/ChatGPT/examples/train_reward_model.py
BlueRum 613efebc5c
[chatgpt] support colossalai strategy to train rm (#2742)
* [chatgpt]fix train_rm bug with lora

* [chatgpt]support colossalai strategy to train rm

* fix pre-commit

* fix pre-commit 2
2023-02-16 11:24:07 +08:00

79 lines
3.0 KiB
Python

import argparse
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import BloomTokenizerFast
from colossalai.nn.optimizer import HybridAdam
def _first_rows(split, limit):
    """Return at most the first ``limit`` rows of a dataset split.

    The requested count is clamped to ``len(split)`` so that a split holding
    fewer than ``limit`` rows yields all of its rows instead of asking
    ``select`` for indices that do not exist.
    """
    return split.select(range(min(limit, len(split))))


def train(args):
    """Train a BLOOM reward model and save the result to ``args.save_path``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``strategy``, ``pretrain``, ``dataset``, ``batch_size``,
            ``max_epochs``, ``lora_rank`` and ``save_path``.

    Raises:
        ValueError: If ``args.strategy`` is not one of the supported names.
    """
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    # Reuse the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): args.lora_rank is not forwarded to BLOOMRM here, yet the
    # save step below branches on it -- confirm whether the model should be
    # constructed with the LoRA rank.
    model = BLOOMRM(pretrained=args.pretrain).cuda()
    max_len = 1024

    # configure optimizer: HybridAdam for the colossalai strategies, plain Adam otherwise
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-5)
    else:
        optim = Adam(model.parameters(), lr=5e-5)

    # prepare for data and dataset
    # Only a small head of each split is used (up to 100 training rows and up
    # to 5 evaluation rows); the counts are clamped so smaller splits still work.
    data = load_dataset(args.dataset)
    train_data = _first_rows(data['train'], 100)
    eval_data = _first_rows(data['test'], 5)
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size is expected to be C(k, 2), where k is the number of responses
    # per prompt. Given the format of the 'Dahoas/rm-static' dataset, a
    # batch_size of 1 is recommended.
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)
    trainer.fit(use_lora=args.lora_rank)

    if args.lora_rank > 0:
        # With LoRA enabled, save only what lora.lora_state_dict returns for the model.
        torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
    else:
        # Otherwise the whole model object is serialized.
        torch.save(trainer.model, args.save_path)
if __name__ == '__main__':
    # Command-line options, declared as (flag, add_argument keyword arguments)
    # pairs and registered in order below.
    option_specs = (
        ('--strategy', {
            'choices': ['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
            'default': 'naive',
        }),
        ('--pretrain', {'type': str, 'default': None}),
        ('--dataset', {'type': str, 'default': 'Dahoas/rm-static'}),
        ('--save_path', {'type': str, 'default': 'rm_ckpt.pth'}),
        ('--max_epochs', {'type': int, 'default': 2}),
        ('--batch_size', {'type': int, 'default': 1}),
        ('--lora_rank', {
            'type': int,
            'default': 0,
            'help': "low-rank adaptation matrices rank",
        }),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    # Parse the command line and hand the resulting namespace to the trainer.
    args = parser.parse_args()
    train(args)