[chat] refactor trainer (#3648)

* [chat] ppo trainer remove useless args

* [chat] update examples

* [chat] update benchmark

* [chat] update examples

* [chat] fix sft training with wandb

* [chat] polish docstr
This commit is contained in:
Hongxin Liu
2023-04-26 18:11:49 +08:00
committed by GitHub
parent f8288315d9
commit 2a951955ad
12 changed files with 72 additions and 536 deletions

View File

@@ -1,156 +0,0 @@
import argparse
from copy import deepcopy
import torch
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import SaveCheckpoint
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
def preprocess_batch(samples):
input_ids = torch.stack(samples)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
return {'input_ids': input_ids, 'attention_mask': attention_mask}
def main(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
with strategy.model_init_context():
if args.model == 'gpt2':
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'bloom':
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'opt':
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
elif args.model == 'roberta':
actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')
initial_model = deepcopy(actor).to(torch.cuda.current_device())
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
# configure optimizer
if args.strategy.startswith('colossalai'):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif args.model == 'roberta':
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
else:
raise ValueError(f'Unsupported model "{args.model}"')
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
callbacks = []
if args.save_ckpt_path:
ckpt_callback = SaveCheckpoint(
args.save_ckpt_path,
args.save_ckpt_interval,
strategy,
actor,
critic,
actor_optim,
critic_optim,
)
callbacks.append(ckpt_callback)
# configure trainer
trainer = PPOTrainer(strategy,
actor,
critic,
reward_model,
initial_model,
actor_optim,
critic_optim,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
tokenizer=preprocess_batch,
max_length=128,
do_sample=True,
temperature=1.0,
top_k=50,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
callbacks=callbacks)
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
trainer.fit(random_prompts, random_pretrain,
num_episodes=args.num_episodes,
max_timesteps=args.max_timesteps,
update_timesteps=args.update_timesteps)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=50)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--max_epochs', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--save_ckpt_path',
type=str,
default=None,
help="path to save checkpoint, None means not to save")
parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
args = parser.parse_args()
main(args)

View File

@@ -1,18 +0,0 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 2
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2

View File

@@ -71,9 +71,8 @@ def main(args):
if args.rm_path is not None:
reward_model.load_state_dict(state_dict)
if args.strategy != 'colossalai_gemini':
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
if args.model == 'gpt2':
@@ -148,9 +147,12 @@ def main(args):
prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None),
sampler=prompt_sampler,
batch_size=args.train_batch_size)
batch_size=args.experience_batch_size)
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
data_path=args.pretrain_dataset,
max_datasets_size=16384,
max_length=args.max_input_len)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
@@ -161,12 +163,6 @@ def main(args):
batch_size=args.ptx_batch_size,
collate_fn=data_collator)
def tokenize_fn(texts):
# MUST padding to max length to ensure inputs of all ranks have the same length
# Different length may lead to hang when using gemini, as different generation steps
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
# configure trainer
@@ -182,9 +178,8 @@ def main(args):
ptx_coef=args.ptx_coef,
max_epochs=args.max_epochs,
train_batch_size=args.train_batch_size,
experience_batch_size=args.experience_batch_size,
tokenizer=tokenize_fn,
max_length=128,
max_length=args.max_seq_len,
use_cache=True,
do_sample=True,
temperature=1.0,
top_k=50,
@@ -232,5 +227,7 @@ if __name__ == '__main__':
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--kl_coef', type=float, default=0.1)
parser.add_argument('--ptx_coef', type=float, default=0.9)
parser.add_argument('--max_input_len', type=int, default=96)
parser.add_argument('--max_seq_len', type=int, default=128)
args = parser.parse_args()
main(args)

View File

@@ -156,7 +156,7 @@ def train(args):
max_epochs=args.max_epochs,
accimulation_steps=args.accimulation_steps)
trainer.fit(logger=logger, log_interval=args.log_interval)
trainer.fit(logger=logger, use_wandb=args.use_wandb)
# save model checkpoint after fitting on only rank0
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
@@ -185,5 +185,6 @@ if __name__ == '__main__':
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
parser.add_argument('--lr', type=float, default=5e-6)
parser.add_argument('--accimulation_steps', type=int, default=8)
parser.add_argument('--use_wandb', default=False, action='store_true')
args = parser.parse_args()
train(args)