ColossalAI/applications/Chat/examples/train_reward_model.py
Camille Zhong 30412866e0
[chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3 (#3223)
* Add RoBERTa for RLHF Stage 2 & 3 (test)

RoBERTa for RLHF Stage 2 & 3 (still in testing)

* Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)"

This reverts commit 06741d894d.

* Add RoBERTa for RLHF stage 2 & 3

1. add roberta folder under the model folder
2. add roberta option in train_reward_model.py
3. add some tests in test_ci

* add test for reward model training

* Update test_ci.sh

* Revert "Update test_ci.sh"

This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a.

* update roberta with coati
2023-04-03 10:11:03 +08:00


import argparse
from random import randint

import loralib as lora
import torch
from coati.dataset import HhRlhfDataset, RmStaticDataset
from coati.models import LogExpLoss, LogSigLoss
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMRM
from coati.models.deberta import DebertaRM
from coati.models.gpt import GPTRM
from coati.models.llama import LlamaRM
from coati.models.opt import OPTRM
from coati.models.roberta import RoBERTaRM
from coati.trainer import RewardModelTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def train(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')
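
    # NOTE: in ColossalAI, 'colossalai_gemini' selects ZeRO stage 3 with Gemini's
    # chunk-based memory management, while 'colossalai_zero2' selects ZeRO stage 2;
    # placement_policy='cuda' keeps shards on GPU rather than offloading to CPU.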
    # configure model
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'deberta':
            model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'llama':
            model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'roberta':
            model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

        if args.model_path is not None:
            state_dict = torch.load(args.model_path)
            model.load_state_dict(state_dict)

        model = model.to(torch.float16)
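
    # NOTE: each *RM class wraps the pretrained backbone with a scalar value head,
    # so the model emits one reward per (prompt, response) sequence. The float16
    # cast assumes GPU training; keep float32 when running on CPU.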
    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'deberta':
        tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
    elif args.model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    max_len = args.max_len

    if args.model == 'llama':
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
    else:
        tokenizer.pad_token = tokenizer.eos_token
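
    # NOTE: GPT-2, BLOOM and OPT tokenizers ship without a dedicated pad token, so
    # EOS is reused for padding; RoBERTa/DeBERTa already define a pad token, which
    # this branch overwrites as well. For LLaMA, coati's helper above prepares the
    # pad token and, per its name, the matching embedding.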
    # configure optimizer
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-6)
    else:
        optim = Adam(model.parameters(), lr=5e-6)

    # configure loss function
    if args.loss_fn == 'log_sig':
        loss_fn = LogSigLoss()
    elif args.loss_fn == 'log_exp':
        loss_fn = LogExpLoss()
    else:
        raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
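
    # Both options, as the class names indicate, implement the pairwise ranking
    # objective on (chosen, rejected) reward pairs and are mathematically equivalent:
    #   LogSigLoss: L = -log(sigmoid(r_chosen - r_rejected))   # InstructGPT-style
    #   LogExpLoss: L = log(1 + exp(r_rejected - r_chosen))
    # since -log(sigmoid(x)) == log(1 + exp(-x)).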
    # prepare for data and dataset
    if args.subset is not None:
        data = load_dataset(args.dataset, data_dir=args.subset)
    else:
        data = load_dataset(args.dataset)

    if args.test:
        train_data = data['train'].select(range(100))
        eval_data = data['test'].select(range(10))
    else:
        train_data = data['train']
        eval_data = data['test']
    valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
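
    # NOTE: randint draws indices *with replacement* (bounded by len(eval_data)),
    # so valid_data is a random ~20%-sized resample of the test split and may
    # contain duplicates; eval_data is kept intact for the final evaluation pass.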
    if args.dataset == 'Dahoas/rm-static':
        train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
        valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
        eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
    elif args.dataset == 'Anthropic/hh-rlhf':
        train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
        valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
        eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
    else:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')
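
    # Both datasets pair each prompt with one chosen and one rejected response;
    # the wrappers tokenize each pair and truncate/pad it to max_len.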
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 loss_fn=loss_fn,
                                 train_dataset=train_dataset,
                                 valid_dataset=valid_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)
    trainer.fit()

    # save the model checkpoint after fitting, on rank 0 only
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save the optimizer checkpoint on all ranks (one shard per CUDA device)
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    # store_true avoids the argparse type=bool pitfall, where any non-empty
    # string (including "False") parses as True
    parser.add_argument('--need_optim_ckpt', action='store_true', default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
                        default='Dahoas/rm-static')
    parser.add_argument('--subset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='rm_ckpt')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
    parser.add_argument('--test', action='store_true', default=False)
    args = parser.parse_args()
    train(args)
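
# Example invocations (an illustrative sketch; the launcher and model names are
# assumptions, not mandated by this script):
#
#   # quick sanity check on a truncated dataset, single GPU:
#   python train_reward_model.py --model roberta --pretrain roberta-base --test
#
#   # multi-GPU training with ColossalAI ZeRO-2:
#   torchrun --standalone --nproc_per_node=4 train_reward_model.py \
#       --model bloom --pretrain bigscience/bloom-560m \
#       --strategy colossalai_zero2 --dataset Anthropic/hh-rlhf --save_path rm_ckpt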