Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 20:10:17 +00:00)
[chat] refactor trainer class (#4080)
* to: add SLTrainer
* refactor: refactor RMTrainer and SFTTrainer
* fix: fix init file
* feat: remove on_learn_epoch fn as not used
* fix: align with modified gemini arguments
* to: add OnPolicyTrainer
* revert: add _on_learn_epoch fn
* refactor: refactor PPOTrainer
* style: rename PPOTrainer argument
* fix: align with modified PPO arguments
* test: align with modified train_prompts arguments
* chore: modify train_prompts
* docs: align with modified arguments
* fix: remove unnecessary output
* fix: move dataloader to fit fn of SLTrainer
* fix: move dataloader to fit fn of OnPolicyTrainer
* fix: modify usage of prompt and pretrain dataloader
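For orientation before the diff: the commit splits the trainers into a supervised base (SLTrainer, shared by SFTTrainer and RMTrainer) and an RL base (OnPolicyTrainer, extended by PPOTrainer), and moves dataloaders out of the constructors into fit(). A minimal sketch of the supervised side, assuming only the names given in the commit message; the bodies are illustrative, not the actual coati code:

from abc import ABC, abstractmethod
from typing import Optional

from torch.utils.data import DataLoader


class SLTrainer(ABC):
    """Sketch of the supervised base shared by SFTTrainer and RMTrainer."""

    def __init__(self, strategy, max_epochs: int, model, optimizer):
        # Note: no dataloader here any more; it is now passed to fit().
        self.strategy = strategy
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer

    @abstractmethod
    def _train(self, epoch: int):
        ...

    def fit(self, train_dataloader: DataLoader, eval_dataloader: Optional[DataLoader] = None):
        # Dataloaders arrive here after the refactor, keeping construction
        # independent of the data pipeline.
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        for epoch in range(self.max_epochs):
            self._train(epoch)

The benchmark hunks below adapt the PPO example to this new interface.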
@@ -137,6 +137,12 @@ def main(args):
 
     (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
 
+    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
+    dataloader = DataLoader(random_prompts,
+                            batch_size=args.experience_batch_size,
+                            shuffle=True,
+                            collate_fn=preprocess_batch)
+
     trainer = PPOTrainer(strategy,
                          actor,
                          critic,
@@ -145,7 +151,6 @@ def main(args):
                          actor_optim,
                          critic_optim,
                          ptx_coef=0,
-                         max_epochs=args.max_epochs,
                          train_batch_size=args.train_batch_size,
                          offload_inference_models=args.offload_inference_models,
                          max_length=512,
@@ -157,17 +162,11 @@ def main(args):
                          eos_token_id=tokenizer.eos_token_id,
                          callbacks=[performance_evaluator])
 
-    random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
-    dataloader = DataLoader(random_prompts,
-                            batch_size=args.experience_batch_size,
-                            shuffle=True,
-                            collate_fn=preprocess_batch)
-
-    trainer.fit(dataloader,
-                None,
+    trainer.fit(prompt_dataloader=dataloader,
+                pretrain_dataloader=None,
                 num_episodes=args.num_episodes,
-                max_timesteps=args.max_timesteps,
-                update_timesteps=args.update_timesteps)
+                num_update_steps=args.num_update_steps,
+                num_collect_steps=args.num_collect_steps)
 
     print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
 
@@ -183,9 +182,8 @@ if __name__ == '__main__':
                        ],
                        default='ddp')
     parser.add_argument('--num_episodes', type=int, default=3)
-    parser.add_argument('--max_timesteps', type=int, default=8)
-    parser.add_argument('--update_timesteps', type=int, default=8)
-    parser.add_argument('--max_epochs', type=int, default=1)
+    parser.add_argument('--num_collect_steps', type=int, default=8)
+    parser.add_argument('--num_update_steps', type=int, default=1)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--experience_batch_size', type=int, default=8)
     parser.add_argument('--lora_rank', type=int, default=0)
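The renamed fit() arguments in the diff above map onto a collect-then-update loop: each episode first gathers rollout experience, then runs PPO updates over the buffered experience. A hypothetical skeleton of that control flow; the real loop lives in OnPolicyTrainer, and the phase methods here are stand-ins:

class OnPolicyTrainer:
    """Sketch of the RL fit loop implied by the renamed arguments."""

    def fit(self, prompt_dataloader, pretrain_dataloader,
            num_episodes, num_collect_steps, num_update_steps):
        self.prompt_dataloader = prompt_dataloader      # prompts for rollouts
        self.pretrain_dataloader = pretrain_dataloader  # optional PTX data (None in this benchmark)
        for _episode in range(num_episodes):
            for _ in range(num_collect_steps):
                self._collect_phase()   # roll out the actor and fill the experience buffer
            for _ in range(num_update_steps):
                self._update_phase()    # run PPO updates over the buffered experience

    def _collect_phase(self):
        raise NotImplementedError   # PPOTrainer supplies the rollout logic

    def _update_phase(self):
        raise NotImplementedError   # PPOTrainer supplies the PPO update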
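For completeness, the DataLoader in the first hunk relies on a preprocess_batch collate function that the diff does not show. A plausible stand-in, assuming it only needs to stack the random prompt-id rows into the input_ids/attention_mask dict a causal-LM actor expects; this is hypothetical, not the repository's helper:

import torch


def preprocess_batch(samples: list[torch.Tensor]) -> dict:
    # Stack the per-sample prompt-id tensors into one batch and attach a
    # trivial all-ones attention mask; real data would mask padding instead.
    input_ids = torch.stack(samples, dim=0)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    return {'input_ids': input_ids, 'attention_mask': attention_mask}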