[chat] remove lm model class (#3653)

* [chat] refactor lora

* [chat] remove lm class

* [chat] refactor save model

* [chat] refactor train sft

* [chat] fix ci

* [chat] fix ci
Author: Hongxin Liu
Date: 2023-04-27 15:37:38 +08:00 (committed by GitHub)
Parent: 8bccb72c8d
Commit: 6ef7011462
20 changed files with 84 additions and 284 deletions
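In short, this change drops the per-architecture LM wrapper classes (BLOOMLM, GPTLM, OPTLM, LlamaLM) and builds the SFT model directly from a stock Hugging Face causal-LM, injecting LoRA layers afterwards. A minimal sketch of the new construction pattern, using only names that appear in the diff below:

from coati.models import convert_to_lora_module
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

# Old style: model = GPTLM(pretrained='gpt2', lora_rank=4)
# New style: a plain HF model, with LoRA injected by a shared helper.
model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained('gpt2'), 4).half().cuda()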

View File

@@ -31,16 +31,19 @@ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigsci
--model 'bloom' --strategy colossalai_zero2 --lora_rank 4\
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
+rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy colossalai_zero2 \
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
+rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
--model 'opt' --strategy colossalai_zero2 --lora_rank 4\
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
--save_path ${BASE}/output
+rm -rf ${BASE}/output
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
--model 'gpt2' --strategy ddp --lora_rank 4\
@@ -59,14 +62,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'facebook/opt-350m' --model 'opt' \
--strategy colossalai_zero2 --loss_fn 'log_sig'\
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
---test True --lora_rank 4 \
+--test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_opt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'gpt2' --model 'gpt2' \
--strategy colossalai_zero2 --loss_fn 'log_exp' \
--dataset 'Dahoas/rm-static' \
---test True --lora_rank 4 \
+--test True --lora_rank 0 \
--save_path ${BASE}/rm_ckpt_gpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
@@ -75,6 +78,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Dahoas/rm-static' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
+rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
@@ -82,6 +86,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
+rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
@@ -89,6 +94,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
--test True --lora_rank 4 \
--save_path ${BASE}/rm_ckpt.pt
+rm -rf ${BASE}/rm_ckpt.pt
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
--pretrain 'roberta-base' --model 'roberta' \
@@ -117,4 +123,4 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
--save_path ${BASE}/actor_checkpoint_prompts.pt
rm -rf ${BASE}/rm_ckpt_gpt.pt
-rm -rf ${BASE}/actor_checkpoint_prompts.pt
+rm -rf ${BASE}/actor_checkpoint_prompts.pt

View File

@@ -5,11 +5,7 @@ import loralib as lora
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
-from coati.models.base import RewardModel
-from coati.models.bloom import BLOOMLM
-from coati.models.gpt import GPTLM
-from coati.models.llama import LlamaLM
-from coati.models.opt import OPTLM
+from coati.models import convert_to_lora_module
from coati.trainer import SFTTrainer
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from coati.utils import prepare_llama_tokenizer_and_embedding
@@ -17,8 +13,12 @@ from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+from transformers.models.opt.configuration_opt import OPTConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
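The convert_to_lora_module helper imported above lives in coati.models and is not part of this diff. The sketch below is only an illustration of what such a helper typically does with loralib (already imported as lora in this file): swap nn.Linear layers for lora.Linear and leave the model untouched when the rank is 0.

import loralib as lora
import torch.nn as nn

def convert_to_lora_module(module: nn.Module, lora_rank: int) -> nn.Module:
    # Illustrative sketch only, not the coati implementation.
    if lora_rank <= 0:
        return module  # rank 0 means "no LoRA"
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # Replace the dense layer with a LoRA-augmented one and keep its weights.
            lora_layer = lora.Linear(child.in_features, child.out_features, r=lora_rank,
                                     bias=child.bias is not None)
            lora_layer.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                lora_layer.bias.data.copy_(child.bias.data)
            setattr(module, name, lora_layer)
        else:
            convert_to_lora_module(child, lora_rank)
    return module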
@@ -32,6 +32,8 @@ def train(args):
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
+raise NotImplementedError(
+'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
@@ -43,16 +45,19 @@ def train(args):
# configure model
with strategy.model_init_context():
if args.model == 'bloom':
-model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
+args.lora_rank).half().cuda()
elif args.model == 'opt':
-model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
elif args.model == 'gpt2':
-model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
elif args.model == 'llama':
-model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank,
-checkpoint=True).to(torch.float16).to(torch.cuda.current_device())
+model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
+args.lora_rank).half().cuda()
else:
raise ValueError(f'Unsupported model "{args.model}"')
+if args.grad_checkpoint:
+model.gradient_checkpointing_enable()
# configure tokenizer
if args.model == 'gpt2':
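With the wrapper classes gone, the object built here is an ordinary transformers model, so standard Hugging Face methods apply directly; the new --grad_checkpoint path above relies on exactly that. For example, mirroring the OPT branch of the diff:

from coati.models import convert_to_lora_module
from transformers.models.opt.modeling_opt import OPTForCausalLM

model = convert_to_lora_module(OPTForCausalLM.from_pretrained('facebook/opt-350m'), 4).half().cuda()
model.gradient_checkpointing_enable()  # standard transformers API: trade compute for activation memory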
@@ -152,7 +157,6 @@ def train(args):
optim=optim,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
-batch_size=args.batch_size,
max_epochs=args.max_epochs,
accimulation_steps=args.accimulation_steps)
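The accimulation_steps value passed to SFTTrainer (the spelling is the project's own) controls gradient accumulation; batch_size is no longer passed, presumably because the dataloaders already carry it. The usual accumulation pattern, sketched generically rather than copied from the trainer, looks like this:

# Generic gradient-accumulation loop, not the SFTTrainer internals.
accumulation_steps = 8
optim.zero_grad()
for step, batch in enumerate(train_dataloader):
    loss = model(**batch).loss
    (loss / accumulation_steps).backward()  # scale so the summed gradient matches one large batch
    if (step + 1) % accumulation_steps == 0:
        optim.step()
        optim.zero_grad()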
@@ -186,5 +190,6 @@ if __name__ == '__main__':
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accimulation_steps', type=int, default=8)
parser.add_argument('--use_wandb', default=False, action='store_true')
+parser.add_argument('--grad_checkpoint', default=False, action='store_true')
args = parser.parse_args()
train(args)
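The "refactor save model" part of the commit touches files not shown here. For a LoRA-converted model, loralib's documented helpers are the natural building blocks for training and checkpointing only the LoRA weights; a small sketch, not necessarily what coati's save path does:

import loralib as lora
import torch

lora.mark_only_lora_as_trainable(model)                   # freeze everything except the LoRA matrices
torch.save(lora.lora_state_dict(model), 'lora_ckpt.pt')   # persist only the small LoRA weights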