[chat] refactor trainer class (#4080)

* feat: add SLTrainer

* refactor: refactor RMTrainer and SFTTrainer

* fix: fix init file

* feat: remove on_learn_epoch fn as not used

* fix: align with modified gemini arguments

* feat: add OnPolicyTrainer

* revert: add _on_learn_epoch fn

* refactor: refactor PPOTrainer

* style: rename PPOTrainer argument

* fix: align with modified PPO arguments

* test: align with modified train_prompts arguments

* chore: modify train_prompts

* docs: align with modified arguments

* fix: remove unnecessary output

* fix: move dataloader to fit fn of SLTrainer

* fix: move dataloader to fit fn of OnPolicyTrainer

* fix: modify usage of prompt and pretrain dataloader
Authored by Wenhao Chen on 2023-06-29 10:48:09 +08:00, committed by GitHub
parent 711e2b4c00
commit b03d64d010
16 changed files with 461 additions and 361 deletions
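Taken together, these commits introduce two trainer base classes: SLTrainer for the supervised phases (SFT and reward-model training) and OnPolicyTrainer for PPO, and move the dataloaders out of the trainer constructors and into fit(). A minimal sketch of the resulting shape, inferred from the commit messages and the diffs below; the signatures are illustrative, not the exact coati API:

# Sketch of the refactored trainer hierarchy, inferred from the commits
# above; method signatures are illustrative, not the exact coati API.
from abc import ABC, abstractmethod


class SLTrainer(ABC):
    """Supervised trainers (SFTTrainer, reward-model trainer) derive from this."""

    def __init__(self, strategy, max_epochs: int, model, optimizer):
        self.strategy, self.max_epochs = strategy, max_epochs
        self.model, self.optimizer = model, optimizer

    @abstractmethod
    def fit(self, train_dataloader, eval_dataloader=None):
        """Dataloaders are now arguments of fit(), not of __init__()."""


class OnPolicyTrainer(ABC):
    """RL trainers (PPOTrainer) derive from this."""

    @abstractmethod
    def fit(self, prompt_dataloader, pretrain_dataloader, num_episodes: int,
            num_collect_steps: int, num_update_steps: int):
        """Each episode collects experience, then runs policy updates."""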

View File

@@ -171,9 +171,8 @@ Pretrain dataset: the pretrain dataset including the instruction and correspondi
 - --pretrain_dataset: path of the ptx dataset, type=str, default=None
 - --need_optim_ckpt: whether to save optim ckpt, type=bool, default=False
 - --num_episodes: num of episodes for training, type=int, default=10
-- --max_epochs: max epochs for training in one episode, type=int, default=5
-- --max_timesteps: max episodes in one batch, type=int, default=10
-- --update_timesteps: timesteps to update, type=int, default=10
+- --num_update_steps: number of steps to update policy per episode, type=int
+- --num_collect_steps: number of steps to collect experience per episode, type=int
 - --train_batch_size: batch size while training, type=int, default=8
 - --ptx_batch_size: batch size to compute ptx loss, type=int, default=1
 - --experience_batch_size: batch size to make experience, type=int, default=8
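In other words, the timestep-based knobs (--max_epochs, --max_timesteps, --update_timesteps) are replaced by an explicit per-episode schedule: collect experience for --num_collect_steps steps, then update the policy for --num_update_steps steps. A runnable toy illustration of that loop; collect_step and update_step are hypothetical stand-ins, not coati functions:

# Toy illustration of the new per-episode schedule; collect_step and
# update_step are hypothetical stand-ins for rollout and PPO updates.
num_episodes, num_collect_steps, num_update_steps = 2, 3, 2

def collect_step(i: int) -> None:
    print(f"  collect experience, step {i}")  # rollout into the replay buffer

def update_step(i: int) -> None:
    print(f"  update policy, step {i}")       # one PPO (+ ptx) optimization step

for episode in range(num_episodes):
    print(f"episode {episode}")
    for i in range(num_collect_steps):
        collect_step(i)
    for i in range(num_update_steps):
        update_step(i)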

View File

@@ -171,7 +171,6 @@ def main(args):
         critic_optim,
         kl_coef=args.kl_coef,
         ptx_coef=args.ptx_coef,
-        max_epochs=args.max_epochs,
         train_batch_size=args.train_batch_size,
         experience_batch_size=args.experience_batch_size,
         tokenizer=tokenize_fn,
@@ -186,8 +185,8 @@ def main(args):
     trainer.fit(prompt_dataloader=prompt_dataloader,
                 pretrain_dataloader=pretrain_dataloader,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_update_steps=args.num_update_steps,
+                num_collect_steps=args.num_collect_steps)

     # save model checkpoint after fitting
     trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
@@ -215,9 +214,8 @@ if __name__ == '__main__':
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
+    parser.add_argument('--num_collect_steps', type=int, default=10)
+    parser.add_argument('--num_update_steps', type=int, default=5)
     parser.add_argument('--train_batch_size', type=int, default=2)
     parser.add_argument('--ptx_batch_size', type=int, default=1)
     parser.add_argument('--experience_batch_size', type=int, default=8)
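With the defaults above, one episode gathers num_collect_steps batches of experience_batch_size rollouts and then consumes num_update_steps mini-batches of train_batch_size. A quick back-of-the-envelope check; this hedges on the assumption that each collect step produces one batch of experience_batch_size rollouts:

# Back-of-the-envelope sizing with this script's defaults; assumes each
# collect step produces one batch of experience_batch_size rollouts.
num_collect_steps = 10     # --num_collect_steps default
num_update_steps = 5       # --num_update_steps default
experience_batch_size = 8  # --experience_batch_size default
train_batch_size = 2       # --train_batch_size default

rollouts_per_episode = num_collect_steps * experience_batch_size  # 80
samples_per_update_pass = num_update_steps * train_batch_size     # 10
print(rollouts_per_episode, samples_per_update_pass)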

View File

@@ -63,8 +63,8 @@ for model in 'gpt2' 'bloom' 'opt' 'llama' 'roberta'; do
         torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
             --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
             --strategy $strategy --model $model \
-            --num_episodes 1 --max_timesteps 2 \
-            --update_timesteps 2 --max_epochs 1 --train_batch_size 2
+            --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
+            --train_batch_size 2
     done
 done
@@ -149,8 +149,8 @@ rm -rf ${BASE}/rm_ckpt.pt
 torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
     --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
-    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
-    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --strategy colossalai_zero2 --num_episodes 1 \
+    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
     --pretrain 'facebook/opt-350m' --model opt \
     --rm_pretrain 'facebook/opt-350m' \
     --rm_path ${BASE}/rm_ckpt_opt.pt \
@@ -159,8 +159,8 @@ rm -rf ${BASE}/rm_ckpt_opt.pt
 torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
     --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
-    --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
-    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --strategy colossalai_zero2 --num_episodes 1 \
+    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
     --pretrain 'gpt2' --model gpt2 \
     --rm_pretrain 'gpt2' \
     --rm_path ${BASE}/rm_ckpt_gpt.pt \
@@ -168,8 +168,8 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
 torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py \
     --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
-    --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
-    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --strategy colossalai_gemini --num_episodes 1 \
+    --num_collect_steps 2 --num_update_steps 1 --train_batch_size 2 \
     --pretrain 'gpt2' --model gpt2 \
     --rm_pretrain 'gpt2' \
     --rm_path ${BASE}/rm_ckpt_gpt.pt \

View File

@@ -177,7 +177,6 @@ def main(args):
         critic_optim,
         kl_coef=args.kl_coef,
         ptx_coef=args.ptx_coef,
-        max_epochs=args.max_epochs,
         train_batch_size=args.train_batch_size,
         max_length=args.max_seq_len,
         use_cache=True,
@@ -192,8 +191,8 @@ def main(args):
     trainer.fit(prompt_dataloader=prompt_dataloader,
                 pretrain_dataloader=pretrain_dataloader,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_collect_steps=args.num_collect_steps,
+                num_update_steps=args.num_update_steps)

     # save model checkpoint after fitting
     strategy.save_model(actor, args.save_path, only_rank0=True)
@@ -220,9 +219,8 @@ if __name__ == '__main__':
     parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
-    parser.add_argument('--max_timesteps', type=int, default=10)
-    parser.add_argument('--update_timesteps', type=int, default=10)
-    parser.add_argument('--max_epochs', type=int, default=5)
+    parser.add_argument('--num_collect_steps', type=int, default=10)
+    parser.add_argument('--num_update_steps', type=int, default=5)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--ptx_batch_size', type=int, default=1)
     parser.add_argument('--experience_batch_size', type=int, default=8)

View File

@@ -1,13 +1,13 @@
 set_n_least_used_CUDA_VISIBLE_DEVICES() {
     local n=${1:-"9999"}
     echo "GPU Memory Usage:"
-    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
-        | tail -n +2 \
-        | nl -v 0 \
-        | tee /dev/tty \
-        | sort -g -k 2 \
-        | awk '{print $1}' \
-        | head -n $n)
+    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
+        tail -n +2 |
+        nl -v 0 |
+        tee /dev/tty |
+        sort -g -k 2 |
+        awk '{print $1}' |
+        head -n $n)
     export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
     echo "Now CUDA_VISIBLE_DEVICES is set to:"
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
@@ -17,4 +17,9 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 2
 # torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
-torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_dataset /path/to/data.json --strategy colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_prompts.py \
+    --pretrain_dataset /path/to/data.json \
+    --prompt_dataset /path/to/data.json \
+    --strategy colossalai_zero2 \
+    --num_episodes 1 --num_collect_steps 2 --num_update_steps 1 \
+    --train_batch_size 2
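For reference, the reflowed pipeline above picks the n GPUs with the least memory currently in use. A rough Python equivalent of the same selection logic; this is a sketch for illustration only, not code from this repository:

# Rough Python equivalent of set_n_least_used_CUDA_VISIBLE_DEVICES;
# a sketch for illustration, not code from this repository.
import os
import subprocess

def set_n_least_used_cuda_visible_devices(n: int) -> None:
    # Query per-GPU used memory in MiB, one line per GPU.
    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        text=True)
    # Pair each GPU index with its used memory and sort ascending.
    usage = sorted(enumerate(int(x) for x in out.split()), key=lambda p: p[1])
    ids = ",".join(str(i) for i, _ in usage[:n])
    os.environ["CUDA_VISIBLE_DEVICES"] = ids
    print(f"CUDA_VISIBLE_DEVICES={ids}")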

View File

@@ -178,12 +178,11 @@ def train(args):
         optim=optim,
         lr_scheduler=lr_scheduler,
         loss_fn=loss_fn,
-        train_dataloader=train_dataloader,
-        valid_dataloader=valid_dataloader,
-        eval_dataloader=eval_dataloader,
         max_epochs=args.max_epochs)
-    trainer.fit()
+    trainer.fit(train_dataloader=train_dataloader,
+                valid_dataloader=valid_dataloader,
+                eval_dataloader=eval_dataloader)

     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks

View File

@@ -170,12 +170,13 @@ def train(args):
         strategy=strategy,
         optim=optim,
         lr_scheduler=lr_scheduler,
-        train_dataloader=train_dataloader,
-        eval_dataloader=eval_dataloader,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps)
-    trainer.fit(logger=logger, use_wandb=args.use_wandb)
+    trainer.fit(train_dataloader=train_dataloader,
+                eval_dataloader=eval_dataloader,
+                logger=logger,
+                use_wandb=args.use_wandb)

     # save model checkpoint after fitting on only rank0
     strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
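Both supervised scripts (reward model above, SFT here) now follow the same pattern: the trainer is constructed with model, optimizer, and schedule state only, and the dataloaders are handed to fit(). A minimal sketch of that pattern; this is a hypothetical illustration mirroring the diffs, not coati code:

# Hypothetical illustration of the "dataloaders move into fit()" pattern
# used by the SFT and reward-model trainers after this refactor.
from typing import Optional
from torch.utils.data import DataLoader


class SupervisedTrainerSketch:
    def __init__(self, model, optim, max_epochs: int):
        self.model, self.optim, self.max_epochs = model, optim, max_epochs

    def fit(self, train_dataloader: DataLoader,
            eval_dataloader: Optional[DataLoader] = None) -> None:
        for _ in range(self.max_epochs):
            for batch in train_dataloader:
                ...  # forward pass, loss, backward, optimizer step
            if eval_dataloader is not None:
                ...  # evaluation pass at the end of each epoch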