diff --git a/applications/ColossalChat/coati/trainer/rm.py b/applications/ColossalChat/coati/trainer/rm.py index 849a90a27..82e4625b9 100755 --- a/applications/ColossalChat/coati/trainer/rm.py +++ b/applications/ColossalChat/coati/trainer/rm.py @@ -48,7 +48,7 @@ class RewardModelTrainer(SLTrainer): model: Any, booster: Booster, optimizer: Optimizer, - plugin: Plugin, + plugin: Plugin, lr_scheduler: _LRScheduler, tokenizer: PreTrainedTokenizerBase, loss_fn: Optional[Callable] = None, @@ -60,7 +60,9 @@ class RewardModelTrainer(SLTrainer): save_dir: str = None, coordinator: DistCoordinator = None, ) -> None: - super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch) + super().__init__( + booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch + ) self.actor_scheduler = lr_scheduler self.tokenizer = tokenizer self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)