[chatgpt]support opt & gpt for rm training (#2876)

Repository: https://github.com/hpcaitech/ColossalAI.git
Commit: 2e16f842a9 (parent: c52edcf0eb)
@@ -1,6 +1,5 @@
 from typing import Optional

-import torch
 import torch.nn as nn
 from transformers import BloomConfig, BloomForCausalLM, BloomModel

@@ -15,12 +15,16 @@ class GPTRM(RewardModel):
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
         checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): Rank of the low-rank approximation.
+        lora_train_bias (str): LoRA bias training mode.
     """

     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:
@@ -29,5 +33,6 @@ class GPTRM(RewardModel):
             model = GPT2Model(GPT2Config())
         if checkpoint:
             model.gradient_checkpointing_enable()
+
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head)
+        super().__init__(model, value_head, lora_rank, lora_train_bias)
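Read together, the two GPTRM hunks above thread the new LoRA options from the constructor signature down to the RewardModel base class. A minimal usage sketch of the patched wrapper follows; the checkpoint name 'gpt2' and rank 8 are illustrative assumptions, not values fixed by this commit.

from chatgpt.nn import GPTRM

# Wrap a pretrained GPT-2 backbone as a reward model. A positive lora_rank is
# forwarded to RewardModel, which is expected to install low-rank adapters of
# that rank; lora_rank=0 keeps the plain full-parameter model.
model = GPTRM(pretrained='gpt2', lora_rank=8, lora_train_bias='none').cuda()
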
@@ -1,8 +1,7 @@
 from typing import Optional

 import torch.nn as nn
-from transformers.models.opt.configuration_opt import OPTConfig
-from transformers.models.opt.modeling_opt import OPTModel
+from transformers import OPTConfig, OPTModel

 from .reward_model import RewardModel

@@ -14,6 +13,7 @@ class OPTRM(RewardModel):
     Args:
         pretrained (str): Pretrained model name or path.
         config (OPTConfig): Model config.
+        checkpoint (bool): Enable gradient checkpointing.
         lora_rank (int): Rank of the low-rank approximation.
         lora_train_bias (str): LoRA bias training mode.
     """
@@ -21,6 +21,7 @@ class OPTRM(RewardModel):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[OPTConfig] = None,
+                 checkpoint: bool = False,
                  lora_rank: int = 0,
                  lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
@@ -29,5 +30,8 @@ class OPTRM(RewardModel):
             model = OPTModel(config)
         else:
             model = OPTModel(OPTConfig())
-        value_head = nn.Linear(model.config.hidden_size, 1)
+        if checkpoint:
+            model.gradient_checkpointing_enable()
+
+        value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias)
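The value-head change matters because OPTModel projects its final hidden states from hidden_size down to word_embed_proj_dim when the two differ (opt-350m, for example, uses 1024 and 512 respectively), so the linear head has to consume word_embed_proj_dim features. Below is a small sketch with a toy config; the tiny dimensions are assumptions chosen only to keep the example cheap to run.

import torch
import torch.nn as nn
from transformers import OPTConfig, OPTModel

# Toy config mimicking opt-350m's layout: hidden_size != word_embed_proj_dim,
# so the decoder adds a projection and emits word_embed_proj_dim features.
config = OPTConfig(hidden_size=64, word_embed_proj_dim=32, ffn_dim=128,
                   num_hidden_layers=2, num_attention_heads=2)
model = OPTModel(config)
value_head = nn.Linear(config.word_embed_proj_dim, 1)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
hidden = model(input_ids).last_hidden_state   # last dim is 32, not 64
reward = value_head(hidden)                   # a hidden_size-sized head would fail here
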
@@ -3,12 +3,13 @@ import argparse
 import loralib as lora
 import torch
 from chatgpt.dataset import RewardDataset
-from chatgpt.nn import BLOOMRM
+from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
 from chatgpt.trainer import RewardModelTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset
 from torch.optim import Adam
-from transformers import BloomTokenizerFast
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

 from colossalai.nn.optimizer import HybridAdam

@@ -27,11 +28,30 @@ def train(args):
         raise ValueError(f'Unsupported strategy "{args.strategy}"')

     # configure model
-    tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
-    tokenizer.pad_token = tokenizer.eos_token
     with strategy.model_init_context():
-        model = BLOOMRM(pretrained=args.pretrain).cuda()
-    max_len = 1024
+        if args.model == 'bloom':
+            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        elif args.model == 'opt':
+            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        elif args.model == 'gpt2':
+            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+        else:
+            raise ValueError(f'Unsupported model "{args.model}"')
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'bloom':
+        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+    tokenizer.pad_token = tokenizer.eos_token
+
+    max_len = 512

     # configure optimizer
     if args.strategy.startswith('colossalai'):
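The per-model tokenizer branch exists mainly because GPT-2 ships without a padding token, which the reward dataset needs for batching, while OPT's tokenizer already defines one. A minimal sketch of that difference follows; running it downloads the 'gpt2' and 'facebook/opt-350m' tokenizers from the Hugging Face hub.

from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

# GPT-2 defines no pad token, so padded batches would fail until one is set;
# reusing the EOS token is the workaround used in the diff above.
gpt2_tok = GPT2Tokenizer.from_pretrained('gpt2')
assert gpt2_tok.pad_token is None
gpt2_tok.pad_token = gpt2_tok.eos_token

# The OPT tokenizer already ships with a pad token, so the trailing
# `tokenizer.pad_token = tokenizer.eos_token` after the if/elif chain
# replaces it with the EOS token.
opt_tok = AutoTokenizer.from_pretrained('facebook/opt-350m')
assert opt_tok.pad_token is not None
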
@@ -58,10 +78,10 @@ def train(args):

     trainer.fit(use_lora=args.lora_rank)

-    if args.lora_rank > 0:
-        torch.save({'model_state_dict': lora.lora_state_dict(trainer.model)}, args.save_path)
-    else:
-        torch.save(trainer.model, args.save_path)
+    # save model checkpoint after fitting on only rank0
+    strategy.save_model(model, 'rm_checkpoint.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
+    strategy.save_optimizer(optim, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)


 if __name__ == '__main__':
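Checkpoint saving now goes through the strategy object rather than raw torch.save calls, with the intent that each strategy implementation can unwrap or gather sharded state as needed: the model is written once from rank 0, while every rank writes its own optimizer file. The sketch below shows the filenames this produces under the two-process launch used in train_rm.sh; the one-GPU-per-process mapping is the usual torchrun assumption, not something this diff enforces.

import torch

# With torchrun --nproc_per_node=2, each process typically owns one GPU, so
# torch.cuda.current_device() is 0 on rank 0 and 1 on rank 1, and a run
# leaves behind:
#   rm_checkpoint.pt           model weights, written by rank 0 only
#   rm_optim_checkpoint_0.pt   optimizer state written by rank 0
#   rm_optim_checkpoint_1.pt   optimizer state written by rank 1
optim_path = 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device())
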
@@ -69,6 +89,7 @@ if __name__ == '__main__':
     parser.add_argument('--strategy',
                         choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                         default='naive')
+    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
@@ -15,4 +15,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {

 set_n_least_used_CUDA_VISIBLE_DEVICES 2

-torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
+# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --model 'gpt2' --strategy colossalai_zero2
+# torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2