Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
[chat] remove lm model class (#3653)
* [chat] refactor lora
* [chat] remove lm class
* [chat] refactor save model
* [chat] refactor train sft
* [chat] fix ci
* [chat] fix ci
This commit is contained in:
@@ -31,16 +31,19 @@ torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'bigsci
|
||||
--model 'bloom' --strategy colossalai_zero2 --lora_rank 4\
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
|
||||
--model 'gpt2' --strategy colossalai_zero2 \
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'facebook/opt-350m' \
|
||||
--model 'opt' --strategy colossalai_zero2 --lora_rank 4\
|
||||
--dataset $SFT_DATASET --max_datasets_size 512 --max_epochs 1 \
|
||||
--save_path ${BASE}/output
|
||||
rm -rf ${BASE}/output
|
||||
|
||||
torchrun --standalone --nproc_per_node=4 ${BASE}/train_sft.py --pretrain 'gpt2' \
|
||||
--model 'gpt2' --strategy ddp --lora_rank 4\
|
||||
@@ -59,14 +62,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'facebook/opt-350m' --model 'opt' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
|
||||
--test True --lora_rank 4 \
|
||||
--test True --lora_rank 0 \
|
||||
--save_path ${BASE}/rm_ckpt_opt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'gpt2' --model 'gpt2' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_exp' \
|
||||
--dataset 'Dahoas/rm-static' \
|
||||
--test True --lora_rank 4 \
|
||||
--test True --lora_rank 0 \
|
||||
--save_path ${BASE}/rm_ckpt_gpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
@@ -75,6 +78,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--dataset 'Dahoas/rm-static' \
|
||||
--test True --lora_rank 4 \
|
||||
--save_path ${BASE}/rm_ckpt.pt
|
||||
rm -rf ${BASE}/rm_ckpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
|
||||
@@ -82,6 +86,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
|
||||
--test True --lora_rank 4 \
|
||||
--save_path ${BASE}/rm_ckpt.pt
|
||||
rm -rf ${BASE}/rm_ckpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
|
||||
@@ -89,6 +94,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base' \
|
||||
--test True --lora_rank 4 \
|
||||
--save_path ${BASE}/rm_ckpt.pt
|
||||
rm -rf ${BASE}/rm_ckpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'roberta-base' --model 'roberta' \
|
||||
@@ -117,4 +123,4 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_datas
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
rm -rf ${BASE}/rm_ckpt_gpt.pt
|
||||
|
||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
@@ -5,11 +5,7 @@ import loralib as lora
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.bloom import BLOOMLM
|
||||
from coati.models.gpt import GPTLM
|
||||
from coati.models.llama import LlamaLM
|
||||
from coati.models.opt import OPTLM
|
||||
from coati.models import convert_to_lora_module
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.utils import prepare_llama_tokenizer_and_embedding
|
||||
@@ -17,8 +13,12 @@ from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, BloomTokenizerFast, LlamaConfig, LlamaForCausalLM
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
@@ -32,6 +32,8 @@ def train(args):
|
||||
elif args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
raise NotImplementedError(
|
||||
'Gemini is not supported .from_pretrained() yet. We will update this after checkpoint io is ready.')
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
@@ -43,16 +45,19 @@ def train(args):
|
||||
# configure model
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = convert_to_lora_module(BloomForCausalLM.from_pretrained(args.pretrain),
|
||||
args.lora_rank).half().cuda()
|
||||
elif args.model == 'opt':
|
||||
model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = convert_to_lora_module(OPTForCausalLM.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
|
||||
elif args.model == 'gpt2':
|
||||
model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
model = convert_to_lora_module(GPT2LMHeadModel.from_pretrained(args.pretrain), args.lora_rank).half().cuda()
|
||||
elif args.model == 'llama':
|
||||
model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank,
|
||||
checkpoint=True).to(torch.float16).to(torch.cuda.current_device())
|
||||
model = convert_to_lora_module(LlamaForCausalLM.from_pretrained(args.pretrain),
|
||||
args.lora_rank).half().cuda()
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
if args.grad_checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
@@ -152,7 +157,6 @@ def train(args):
|
||||
optim=optim,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
accimulation_steps=args.accimulation_steps)
|
||||
|
||||
@@ -186,5 +190,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--lr', type=float, default=5e-6)
|
||||
parser.add_argument('--accimulation_steps', type=int, default=8)
|
||||
parser.add_argument('--use_wandb', default=False, action='store_true')
|
||||
parser.add_argument('--grad_checkpoint', default=False, action='store_true')
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
Reference in New Issue
Block a user