diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index 5b1b8d3d1..fb9802e38 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -150,9 +150,7 @@ def train(args):
                                   pin_memory=True)
 
     lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
-    strategy_dict = strategy.prepare(
-        dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
-    )
+    strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
     model = strategy_dict['model']
     optim = strategy_dict['optimizer']
     lr_scheduler = strategy_dict['lr_scheduler']
@@ -163,9 +161,7 @@ def train(args):
                       loss_fn=loss_fn,
                       max_epochs=args.max_epochs)
 
-    trainer.fit(train_dataloader=train_dataloader,
-                valid_dataloader=valid_dataloader,
-                eval_dataloader=eval_dataloader)
+    trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
     # save model checkpoint after fitting on only rank0
     strategy.save_model(model, args.save_path, only_rank0=True)
     # save optimizer checkpoint on all ranks