Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2026-01-29 21:49:54 +00:00)
[chat] refactor trainer (#3648)
* [chat] ppo trainer remove useless args
* [chat] update examples
* [chat] update benchmark
* [chat] update examples
* [chat] fix sft training with wandb
* [chat] polish docstr
@@ -1,156 +0,0 @@
import argparse
from copy import deepcopy

import torch
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.opt import OPTActor, OPTCritic
from coati.models.roberta import RoBERTaActor, RoBERTaCritic
from coati.trainer import PPOTrainer
from coati.trainer.callbacks import SaveCheckpoint
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def preprocess_batch(samples):
    input_ids = torch.stack(samples)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}

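For reference, a minimal sketch of what preprocess_batch hands back to the PPO generation step. This snippet is not part of the deleted file; the batch size, prompt length, and vocabulary size are illustrative only.

samples = [torch.randint(0, 50257, (64,)) for _ in range(4)]   # four hypothetical 64-token prompts
batch = preprocess_batch(samples)
print(batch['input_ids'].shape)       # torch.Size([4, 64])
print(batch['attention_mask'].shape)  # torch.Size([4, 64]), all ones, dtype torch.long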
def main(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'gpt2':
            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'bloom':
            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'opt':
            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        elif args.model == 'roberta':
            actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
            critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

    initial_model = deepcopy(actor).to(torch.cuda.current_device())
    reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())

    # configure optimizer
    if args.strategy.startswith('colossalai'):
        actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
        critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
    else:
        actor_optim = Adam(actor.parameters(), lr=5e-6)
        critic_optim = Adam(critic.parameters(), lr=5e-6)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif args.model == 'roberta':
        tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
        (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)

    callbacks = []
    if args.save_ckpt_path:
        ckpt_callback = SaveCheckpoint(
            args.save_ckpt_path,
            args.save_ckpt_interval,
            strategy,
            actor,
            critic,
            actor_optim,
            critic_optim,
        )
        callbacks.append(ckpt_callback)

    # configure trainer

    trainer = PPOTrainer(strategy,
                         actor,
                         critic,
                         reward_model,
                         initial_model,
                         actor_optim,
                         critic_optim,
                         max_epochs=args.max_epochs,
                         train_batch_size=args.train_batch_size,
                         tokenizer=preprocess_batch,
                         max_length=128,
                         do_sample=True,
                         temperature=1.0,
                         top_k=50,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         callbacks=callbacks)

    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 1, 64), device=torch.cuda.current_device())
    random_attention_mask = torch.randint(1, (1000, 1, 64), device=torch.cuda.current_device()).to(torch.bool)
    random_pretrain = [{'input_ids':random_prompts[i], 'labels':random_prompts[i], 'attention_mask':random_attention_mask[i]} for i in range(1000)]
    trainer.fit(random_prompts, random_pretrain,
                num_episodes=args.num_episodes,
                max_timesteps=args.max_timesteps,
                update_timesteps=args.update_timesteps)

    # save model checkpoint after fitting
    trainer.save_model(args.save_path, only_rank0=True)
    # save optimizer checkpoint on all ranks
    if args.need_optim_ckpt:
        strategy.save_optimizer(actor_optim,
                                'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
                                only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=50)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--max_epochs', type=int, default=5)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--save_ckpt_path',
                        type=str,
                        default=None,
                        help="path to save checkpoint, None means not to save")
    parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
    args = parser.parse_args()
    main(args)
@@ -1,18 +0,0 @@
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
        | tail -n +2 \
        | nl -v 0 \
        | tee /dev/tty \
        | sort -g -k 2 \
        | awk '{print $1}' \
        | head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}

set_n_least_used_CUDA_VISIBLE_DEVICES 2

torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
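The shell helper above exports CUDA_VISIBLE_DEVICES with the n GPUs that currently have the least memory in use. A rough Python equivalent, purely illustrative and not part of this patch (the function name is hypothetical, and it only has an effect if it runs before CUDA is initialized in the process):

import os
import subprocess

def set_n_least_used_cuda_visible_devices(n: int = 2) -> None:
    # Query per-GPU memory usage (MiB) from nvidia-smi, one line per device.
    out = subprocess.check_output(
        ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits'], text=True)
    used = [int(line) for line in out.strip().splitlines()]
    # Keep the indices of the n least-loaded GPUs.
    least_used = sorted(range(len(used)), key=lambda i: used[i])[:n]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in least_used)
    print(f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")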
@@ -71,9 +71,8 @@ def main(args):
     if args.rm_path is not None:
         reward_model.load_state_dict(state_dict)

-    if args.strategy != 'colossalai_gemini':
-        initial_model.to(torch.float16).to(torch.cuda.current_device())
-        reward_model.to(torch.float16).to(torch.cuda.current_device())
+    initial_model.to(torch.float16).to(torch.cuda.current_device())
+    reward_model.to(torch.float16).to(torch.cuda.current_device())

     with strategy.model_init_context():
         if args.model == 'gpt2':
@@ -148,9 +147,12 @@ def main(args):
     prompt_dataloader = DataLoader(prompt_dataset,
                                    shuffle=(prompt_sampler is None),
                                    sampler=prompt_sampler,
-                                   batch_size=args.train_batch_size)
+                                   batch_size=args.experience_batch_size)

-    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
+    pretrain_dataset = SupervisedDataset(tokenizer=tokenizer,
+                                         data_path=args.pretrain_dataset,
+                                         max_datasets_size=16384,
+                                         max_length=args.max_input_len)
     if dist.is_initialized() and dist.get_world_size() > 1:
         pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
     else:
@@ -161,12 +163,6 @@ def main(args):
                                          batch_size=args.ptx_batch_size,
                                          collate_fn=data_collator)

-    def tokenize_fn(texts):
-        # MUST padding to max length to ensure inputs of all ranks have the same length
-        # Different length may lead to hang when using gemini, as different generation steps
-        batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
-        return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
-
     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

     # configure trainer
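The tokenize_fn shown in this hunk pads every batch to a fixed max_length so that all data-parallel ranks receive identically shaped inputs and run the same number of generation steps; as the inline comment notes, mismatched lengths can hang Gemini. A self-contained sketch of that behavior, with the GPT-2 tokenizer assumed purely for illustration:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
batch = tok(["a short prompt", "a noticeably longer prompt in the same batch"],
            return_tensors='pt', max_length=96, padding='max_length', truncation=True)
print(batch['input_ids'].shape)  # torch.Size([2, 96]), identical length on every rank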
@@ -182,9 +178,8 @@ def main(args):
                          ptx_coef=args.ptx_coef,
                          max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
-                         experience_batch_size=args.experience_batch_size,
-                         tokenizer=tokenize_fn,
-                         max_length=128,
+                         max_length=args.max_seq_len,
+                         use_cache=True,
                          do_sample=True,
                          temperature=1.0,
                          top_k=50,
@@ -232,5 +227,7 @@ if __name__ == '__main__':
     parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
     parser.add_argument('--kl_coef', type=float, default=0.1)
     parser.add_argument('--ptx_coef', type=float, default=0.9)
+    parser.add_argument('--max_input_len', type=int, default=96)
+    parser.add_argument('--max_seq_len', type=int, default=128)
     args = parser.parse_args()
     main(args)
@@ -156,7 +156,7 @@ def train(args):
                          max_epochs=args.max_epochs,
                          accimulation_steps=args.accimulation_steps)

-    trainer.fit(logger=logger, log_interval=args.log_interval)
+    trainer.fit(logger=logger, use_wandb=args.use_wandb)

     # save model checkpoint after fitting on only rank0
     trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
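With this change, Weights & Biases logging in the SFT training example becomes opt-in: fit() now receives use_wandb, wired to the new --use_wandb flag added in the next hunk (action='store_true', so it defaults to off), and no longer takes the log_interval argument.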
@@ -185,5 +185,6 @@ if __name__ == '__main__':
     parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
     parser.add_argument('--lr', type=float, default=5e-6)
     parser.add_argument('--accimulation_steps', type=int, default=8)
+    parser.add_argument('--use_wandb', default=False, action='store_true')
     args = parser.parse_args()
     train(args)