mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[chat] refactor trainer class (#4080)
* to: add SLTrainer
* refactor: refactor RMTrainer and SFTTrainer
* fix: fix init file
* feat: remove on_learn_epoch fn as not used
* fix: align with modified gemini arguments
* to: add OnPolicyTrainer
* revert: add _on_learn_epoch fn
* refactor: refactor PPOTrainer
* style: rename PPOTrainer argument
* fix: align with modified PPO arguments
* test: align with modified train_prompts arguments
* chore: modify train_prompts
* docs: align with modified arguments
* fix: remove unnecessary output
* fix: move dataloader to fit fn of SLTrainer
* fix: move dataloader to fit fn of OnPolicyTrainer
* fix: modify usage of prompt and pretrain dataloader
This commit is contained in:
@@ -171,9 +171,8 @@ Pretrain dataset: the pretrain dataset including the instruction and correspondi
|
||||
- --pretrain_dataset: path of the ptx dataset, type=str, default=None
|
||||
- --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
|
||||
- --num_episodes: num of episodes for training, type=int, default=10
|
||||
- --max_epochs: max epochs for training in one episode, type=int, default=5
|
||||
- --max_timesteps: max episodes in one batch, type=int, default=10
|
||||
- --update_timesteps: timesteps to update, type=int, default=10
|
||||
- --num_update_steps: number of steps to update policy per episode, type=int, default=5
|
||||
- --num_collect_steps: number of steps to collect experience per episode, type=int, default=10
|
||||
- --train_batch_size: batch size while training, type=int, default=8
|
||||
- --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
|
||||
- --experience_batch_size: batch size to make experience, type=int, default=8
|
||||
|
||||
@@ -171,7 +171,6 @@ def main(args):
|
||||
critic_optim,
|
||||
kl_coef=args.kl_coef,
|
||||
ptx_coef=args.ptx_coef,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
tokenizer=tokenize_fn,
|
||||
@@ -186,8 +185,8 @@ def main(args):
|
||||
trainer.fit(prompt_dataloader=prompt_dataloader,
|
||||
pretrain_dataloader=pretrain_dataloader,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
num_update_steps=args.num_update_steps,
|
||||
num_collect_steps=args.num_collect_steps)
|
||||
|
||||
# save model checkpoint after fitting
|
||||
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
@@ -215,9 +214,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--num_collect_steps', type=int, default=10)
|
||||
parser.add_argument('--num_update_steps', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=2)
|
||||
parser.add_argument('--ptx_batch_size', type=int, default=1)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
|
||||
@@ -63,8 +63,8 @@ for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy $strategy --model $model \
|
||||
--num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2
|
||||
--num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
|
||||
--train_batch_size 2
|
||||
done
|
||||
done
|
||||
|
||||
@@ -149,8 +149,8 @@ rm -rf ${BASE}/rm_ckpt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
|
||||
--strategy colossalai_zero2 --num_episodes 1 \
|
||||
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||
--pretrain 'facebook/opt-350m' --model opt \
|
||||
--rm_pretrain 'facebook/opt-350m' \
|
||||
--rm_path ${BASE}/rm_ckpt_opt.pt \
|
||||
@@ -159,8 +159,8 @@ rm -rf ${BASE}/rm_ckpt_opt.pt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
|
||||
--strategy colossalai_zero2 --num_episodes 1 \
|
||||
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||
--pretrain 'gpt2' --model gpt2 \
|
||||
--rm_pretrain 'gpt2' \
|
||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||
@@ -168,8 +168,8 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
|
||||
--prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
|
||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
|
||||
--strategy colossalai_gemini --num_episodes 1 \
|
||||
--num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
|
||||
--pretrain 'gpt2' --model gpt2 \
|
||||
--rm_pretrain 'gpt2' \
|
||||
--rm_path ${BASE}/rm_ckpt_gpt.pt \
|
||||
|
||||
@@ -177,7 +177,6 @@ def main(args):
|
||||
critic_optim,
|
||||
kl_coef=args.kl_coef,
|
||||
ptx_coef=args.ptx_coef,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
max_length=args.max_seq_len,
|
||||
use_cache=True,
|
||||
@@ -192,8 +191,8 @@ def main(args):
|
||||
trainer.fit(prompt_dataloader=prompt_dataloader,
|
||||
pretrain_dataloader=pretrain_dataloader,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
num_collect_steps=args.num_collect_steps,
|
||||
num_update_steps=args.num_update_steps)
|
||||
|
||||
# save model checkpoint after fitting
|
||||
strategy.save_model(actor, args.save_path, only_rank0=True)
|
||||
@@ -220,9 +219,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--num_collect_steps', type=int, default=10)
|
||||
parser.add_argument('--num_update_steps', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--ptx_batch_size', type=int, default=1)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
|
||||
tail -n +2 |
|
||||
nl -v 0 |
|
||||
tee /dev/tty |
|
||||
sort -g -k 2 |
|
||||
awk '{print $1}' |
|
||||
head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
@@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
# torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
|
||||
torchrun --standalone --nproc_per_node=2 train_prompts.py \
|
||||
--pretrain_dataset /path/to/data.json \
|
||||
--prompt_dataset /path/to/data.json \
|
||||
--strategy colossalai_zero2 \
|
||||
--num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
|
||||
--train_batch_size 2
|
||||
|
||||
@@ -178,12 +178,11 @@ def train(args):
|
||||
optim=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
loss_fn=loss_fn,
|
||||
train_dataloader=train_dataloader,
|
||||
valid_dataloader=valid_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
max_epochs=args.max_epochs)
|
||||
|
||||
trainer.fit()
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
valid_dataloader=valid_dataloader,
|
||||
eval_dataloader=eval_dataloader)
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(model, args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
|
||||
@@ -170,12 +170,13 @@ def train(args):
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
lr_scheduler=lr_scheduler,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps)
|
||||
|
||||
trainer.fit(logger=logger, use_wandb=args.use_wandb)
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
logger=logger,
|
||||
use_wandb=args.use_wandb)
|
||||
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user