[chat] refactor trainer class (#4080)

* to: add SLTrainer

* refactor: refactor RMTrainer and SFTTrainer

* fix: fix init file

* feat: remove on_learn_epoch fn as not used

* fix: align with modified gemini arguments

* to: add OnPolicyTrainer

* revert: add _on_learn_epoch fn

* refactor: refactor PPOTrainer

* style: rename PPOTrainer argument

* fix: align with modified PPO arguments

* test: align with modified train_prompts arguments

* chore: modify train_prompts

* docs: align with modified arguments

* fix: remove unnecessary output

* fix: move dataloader to fit fn of SLTrainer

* fix: move dataloader to fit fn of OnPolicyTrainer

* fix: modify usage of prompt and pretrain dataloader

Wenhao Chen, 2023-06-29 10:48:09 +08:00 (committed by GitHub)
parent 711e2b4c00
commit b03d64d010
16 changed files with 461 additions and 361 deletions
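
The bullets above amount to two new base classes: SLTrainer for the supervised stages (SFT and reward-model training) and OnPolicyTrainer for PPO, with the dataloaders moved out of the constructors and into fit(). A minimal sketch of that shape, assuming only the class, method, and argument names visible in the commit messages and the diff below; the bodies and exact signatures are illustrative, not coati's actual implementation:

    from abc import ABC, abstractmethod

    class SLTrainer(ABC):
        """Shared base for the supervised stages (SFT and reward model)."""

        def __init__(self, strategy, max_epochs, model, optimizer):
            self.strategy = strategy
            self.max_epochs = max_epochs
            self.model = model
            self.optimizer = optimizer

        @abstractmethod
        def _train(self, epoch):
            ...

        def fit(self, train_dataloader):
            # Dataloaders arrive here rather than in __init__
            # (the "move dataloader to fit fn" commits above).
            self.train_dataloader = train_dataloader
            for epoch in range(self.max_epochs):
                self._train(epoch)

    class OnPolicyTrainer(ABC):
        """Shared base for RL trainers such as PPO."""

        @abstractmethod
        def _collect(self, step):
            ...

        @abstractmethod
        def _update(self, step):
            ...

        def fit(self, prompt_dataloader, pretrain_dataloader,
                num_episodes, num_collect_steps, num_update_steps):
            self.prompt_dataloader = prompt_dataloader
            self.pretrain_dataloader = pretrain_dataloader
            for _ in range(num_episodes):
                for step in range(num_collect_steps):
                    self._collect(step)   # gather rollout experience
                for step in range(num_update_steps):
                    self._update(step)    # PPO passes over the buffer

Moving the dataloaders into fit() decouples trainer construction from the data pipeline, which is what the benchmark change below relies on: the PPOTrainer is built first and the prompt dataloader is handed over at fit time.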


@@ -137,6 +137,12 @@ def main(args):
     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
 
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
+    dataloader = DataLoader(random_prompts,
+                            batch_size=args.experience_batch_size,
+                            shuffle=True,
+                            collate_fn=preprocess_batch)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
@@ -145,7 +151,6 @@ def main(args):
                          actor_optim,
                          critic_optim,
                          ptx_coef=0,
-                         max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
                          offload_inference_models=args.offload_inference_models,
                          max_length=512,
@@ -157,17 +162,11 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=[performance_evaluator])
 
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
-    dataloader = DataLoader(random_prompts,
-                            batch_size=args.experience_batch_size,
-                            shuffle=True,
-                            collate_fn=preprocess_batch)
-
-    trainer.fit(dataloader,
-                None,
+    trainer.fit(prompt_dataloader=dataloader,
+                pretrain_dataloader=None,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_update_steps=args.num_update_steps,
+                num_collect_steps=args.num_collect_steps)
+
     print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
@@ -183,9 +182,8 @@ if __name__ == '__main__':
                         ],
                         default='ddp')
     parser.add_argument('--num_episodes', type=int, default=3)
-    parser.add_argument('--max_timesteps', type=int, default=8)
-    parser.add_argument('--update_timesteps', type=int, default=8)
-    parser.add_argument('--max_epochs', type=int, default=1)
+    parser.add_argument('--num_collect_steps', type=int, default=8)
+    parser.add_argument('--num_update_steps', type=int, default=1)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0)
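
The flag rename in the last hunk also makes the loop phases explicit: the old max_timesteps/update_timesteps pair implied one interleaved loop that updated every few collection steps, while num_collect_steps/num_update_steps split each episode into a collect phase followed by an update phase. A rough before/after sketch using the argparse defaults above; the loop shapes are inferred from the argument names, and collect()/update() are hypothetical stubs, not coati's experience maker or PPO step:

    # Hypothetical stand-ins for experience collection and the PPO update.
    def collect():
        return object()          # one batch of rollout experience

    def update(buffer):
        pass                     # one PPO update pass over the buffer

    num_episodes = 3

    # Before: one interleaved loop; an update fires every `update_timesteps`
    # collection steps (defaults: max_timesteps=8, update_timesteps=8).
    max_timesteps, update_timesteps = 8, 8
    for _ in range(num_episodes):
        buffer = []
        for t in range(max_timesteps):
            buffer.append(collect())
            if (t + 1) % update_timesteps == 0:
                update(buffer)
                buffer = []

    # After: each episode is an explicit collect phase, then an update phase
    # (defaults: num_collect_steps=8, num_update_steps=1).
    num_collect_steps, num_update_steps = 8, 1
    for _ in range(num_episodes):
        buffer = [collect() for _ in range(num_collect_steps)]
        for _ in range(num_update_steps):
            update(buffer)

With the new defaults, each episode collects eight experience batches and then runs a single update step over them, so the old and new defaults describe the same overall schedule.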