mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-06 18:43:58 +00:00
[NFC] polish applications/Chat/examples/train_reward_model.py code style (#4271)
This commit is contained in:
parent
a50d39a143
commit
1ce997daaf
@ -150,9 +150,7 @@ def train(args):
|
|||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
|
|
||||||
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
|
lr_scheduler = CosineAnnealingLR(optim, train_dataloader.__len__() // 100)
|
||||||
strategy_dict = strategy.prepare(
|
strategy_dict = strategy.prepare(dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler))
|
||||||
dict(model=model, optimizer=optim, lr_scheduler=lr_scheduler)
|
|
||||||
)
|
|
||||||
model = strategy_dict['model']
|
model = strategy_dict['model']
|
||||||
optim = strategy_dict['optimizer']
|
optim = strategy_dict['optimizer']
|
||||||
lr_scheduler = strategy_dict['lr_scheduler']
|
lr_scheduler = strategy_dict['lr_scheduler']
|
||||||
@ -163,9 +161,7 @@ def train(args):
|
|||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
max_epochs=args.max_epochs)
|
max_epochs=args.max_epochs)
|
||||||
|
|
||||||
trainer.fit(train_dataloader=train_dataloader,
|
trainer.fit(train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, eval_dataloader=eval_dataloader)
|
||||||
valid_dataloader=valid_dataloader,
|
|
||||||
eval_dataloader=eval_dataloader)
|
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
strategy.save_model(model, args.save_path, only_rank0=True)
|
strategy.save_model(model, args.save_path, only_rank0=True)
|
||||||
# save optimizer checkpoint on all ranks
|
# save optimizer checkpoint on all ranks
|
||||||
|
Loading…
Reference in New Issue
Block a user