mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-17 23:46:52 +00:00
[chatgpt]add reward model code for deberta (#3199)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
This commit is contained in:
parent
1e1b9d2fea
commit
9998d5ef64
4
applications/ChatGPT/chatgpt/models/deberta/__init__.py
Normal file
4
applications/ChatGPT/chatgpt/models/deberta/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .deberta_critic import DebertaCritic
|
||||||
|
from .deberta_rm import DebertaRM
|
||||||
|
|
||||||
|
__all__ = ['DebertaCritic', 'DebertaRM']
|
@ -0,0 +1,36 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import DebertaV2Config, DebertaV2Model
|
||||||
|
|
||||||
|
from ..base import Critic
|
||||||
|
|
||||||
|
|
||||||
|
class DebertaCritic(Critic):
|
||||||
|
"""
|
||||||
|
Deberta Critic model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (str): Pretrained model name or path.
|
||||||
|
config (DebertaV2Config): Model config.
|
||||||
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the LO-RA decomposition.
|
||||||
|
lora_train_bias (str): LoRA bias training mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
pretrained: Optional[str] = None,
|
||||||
|
config: Optional[DebertaV2Config] = None,
|
||||||
|
checkpoint: bool = False,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none') -> None:
|
||||||
|
if pretrained is not None:
|
||||||
|
model = DebertaV2Model.from_pretrained(pretrained)
|
||||||
|
elif config is not None:
|
||||||
|
model = DebertaV2Model(config)
|
||||||
|
else:
|
||||||
|
model = DebertaV2Model(DebertaV2Config())
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||||
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
37
applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
Normal file
37
applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import DebertaV2Config, DebertaV2Model
|
||||||
|
|
||||||
|
from ..base import RewardModel
|
||||||
|
|
||||||
|
|
||||||
|
class DebertaRM(RewardModel):
|
||||||
|
"""
|
||||||
|
Deberta Reward model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained (str): Pretrained model name or path.
|
||||||
|
config (DebertaV2Config): Model config.
|
||||||
|
checkpoint (bool): Enable gradient checkpointing.
|
||||||
|
lora_rank (int): Rank of the LO-RA decomposition.
|
||||||
|
lora_train_bias (str): LoRA bias training mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
pretrained: str = None,
|
||||||
|
config: Optional[DebertaV2Config] = None,
|
||||||
|
checkpoint: bool = False,
|
||||||
|
lora_rank: int = 0,
|
||||||
|
lora_train_bias: str = 'none') -> None:
|
||||||
|
if pretrained is not None:
|
||||||
|
model = DebertaV2Model.from_pretrained(pretrained)
|
||||||
|
elif config is not None:
|
||||||
|
model = DebertaV2Model(config)
|
||||||
|
else:
|
||||||
|
model = DebertaV2Model(DebertaV2Config())
|
||||||
|
if checkpoint:
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||||
|
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||||
|
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
@ -1 +1,2 @@
|
|||||||
pandas>=1.4.1
|
pandas>=1.4.1
|
||||||
|
sentencepiece
|
||||||
|
@ -88,4 +88,10 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
|||||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||||
--test True --lora_rank 4
|
--test True --lora_rank 4
|
||||||
|
|
||||||
|
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||||
|
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
|
||||||
|
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||||
|
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||||
|
--test True --lora_rank 4
|
||||||
|
|
||||||
rm -rf ${BASE}/rm_ckpt.pt
|
rm -rf ${BASE}/rm_ckpt.pt
|
||||||
|
@ -8,12 +8,13 @@ from chatgpt.models.base import RewardModel
|
|||||||
from chatgpt.models.bloom import BLOOMRM
|
from chatgpt.models.bloom import BLOOMRM
|
||||||
from chatgpt.models.gpt import GPTRM
|
from chatgpt.models.gpt import GPTRM
|
||||||
from chatgpt.models.opt import OPTRM
|
from chatgpt.models.opt import OPTRM
|
||||||
|
from chatgpt.models.deberta import DebertaRM
|
||||||
from chatgpt.trainer import RewardModelTrainer
|
from chatgpt.trainer import RewardModelTrainer
|
||||||
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from random import randint
|
from random import randint
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
|
||||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
@ -39,6 +40,8 @@ def train(args):
|
|||||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
elif args.model == 'gpt2':
|
elif args.model == 'gpt2':
|
||||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
|
elif args.model == 'deberta':
|
||||||
|
model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
|
|
||||||
@ -54,6 +57,8 @@ def train(args):
|
|||||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||||
elif args.model == 'opt':
|
elif args.model == 'opt':
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
|
elif args.model == 'deberta':
|
||||||
|
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unsupported model "{args.model}"')
|
raise ValueError(f'Unsupported model "{args.model}"')
|
||||||
max_len = args.max_len
|
max_len = args.max_len
|
||||||
@ -119,7 +124,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--strategy',
|
parser.add_argument('--strategy',
|
||||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||||
default='naive')
|
default='naive')
|
||||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
|
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
|
||||||
parser.add_argument('--pretrain', type=str, default=None)
|
parser.add_argument('--pretrain', type=str, default=None)
|
||||||
parser.add_argument('--model_path', type=str, default=None)
|
parser.add_argument('--model_path', type=str, default=None)
|
||||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 1
|
set_n_least_used_CUDA_VISIBLE_DEVICES 1
|
||||||
|
|
||||||
python train_reward_model.py --pretrain '/home/lczht/data2/bloom-560m' \
|
python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
|
||||||
--model 'bloom' \
|
--model 'deberta' \
|
||||||
--strategy naive \
|
--strategy naive \
|
||||||
--loss_fn 'log_exp'\
|
--loss_fn 'log_exp'\
|
||||||
--save_path 'rmstatic.pt' \
|
--save_path 'rmstatic.pt' \
|
||||||
|
Loading…
Reference in New Issue
Block a user