From 730f88f8e1c1a034ada008d7006922bf07571377 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Tue, 18 Oct 2022 09:37:30 +0800 Subject: [PATCH] [NFC] polish _checkpoint_hook.py code style (#1722) --- colossalai/trainer/hooks/_checkpoint_hook.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py index d260ddcbf..3bcb32cd2 100644 --- a/colossalai/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -50,32 +50,23 @@ class SaveCheckpointHook(BaseHook): break self.model = self.model if self.model is not None else trainer.engine.model - def after_train_iter(self, trainer, output, label, loss): """Saves the model after a training iter. """ # save by interval if self.save_by_iter and trainer.cur_step % self.interval == 0: - save_checkpoint(self.checkpoint_dir, - trainer.cur_epoch, - self.model, - trainer.engine.optimizer, + save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler) - self.logger.info( - f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', ranks=[0]) + self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', + ranks=[0]) else: pass - def after_train_epoch(self, trainer): """Saves the model after a training epoch. """ # save by interval if trainer.cur_epoch % self.interval == 0: - save_checkpoint(self.checkpoint_dir, - trainer.cur_epoch, - self.model, - trainer.engine.optimizer, + save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler) - self.logger.info( - f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0]) + self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])