Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -27,12 +27,14 @@ class SaveCheckpointHook(BaseHook):
         depend on the hooks order in the hook list.
     """
 
-    def __init__(self,
-                 interval: int = 1,
-                 checkpoint_dir: str = None,
-                 model: torch.nn.Module = None,
-                 save_by_iter: bool = False,
-                 priority: int = 10):
+    def __init__(
+        self,
+        interval: int = 1,
+        checkpoint_dir: str = None,
+        model: torch.nn.Module = None,
+        save_by_iter: bool = False,
+        priority: int = 10,
+    ):
         super().__init__(priority=priority)
         self.interval = interval
         self.checkpoint_dir = checkpoint_dir
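The reformatted constructor keeps the same keyword arguments, so callers are unaffected. A minimal instantiation sketch; the import path is an assumption (it has moved between releases) and the checkpoint directory is illustrative:

from colossalai.trainer import hooks  # assumed import path; may differ across ColossalAI versions

# Save a checkpoint every 5 epochs; with save_by_iter=True the interval counts iterations instead.
ckpt_hook = hooks.SaveCheckpointHook(
    interval=5,
    checkpoint_dir="./checkpoints",  # illustrative path
    model=None,                      # None falls back to trainer.engine.model (see the next hunk)
    save_by_iter=False,
    priority=10,
)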
@@ -52,22 +54,23 @@ class SaveCheckpointHook(BaseHook):
         self.model = self.model if self.model is not None else trainer.engine.model
 
     def after_train_iter(self, trainer, output, label, loss):
-        """Saves the model after a training iter.
-        """
+        """Saves the model after a training iter."""
         # save by interval
         if self.save_by_iter and trainer.cur_step % self.interval == 0:
-            save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
-                            self._lr_scheduler)
-            self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}',
-                             ranks=[0])
+            save_checkpoint(
+                self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler
+            )
+            self.logger.info(
+                f"checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}", ranks=[0]
+            )
         else:
             pass
 
     def after_train_epoch(self, trainer):
-        """Saves the model after a training epoch.
-        """
+        """Saves the model after a training epoch."""
         # save by interval
         if trainer.cur_epoch % self.interval == 0:
-            save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
-                            self._lr_scheduler)
-            self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
+            save_checkpoint(
+                self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer, self._lr_scheduler
+            )
+            self.logger.info(f"checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}", ranks=[0])
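Both hooks follow the same save-by-interval pattern: a counter (epoch or iteration) is checked against interval, and a checkpoint is written only when it divides evenly. Below is a standalone sketch of that pattern using plain PyTorch rather than ColossalAI's save_checkpoint helper; the model, optimizer, and path names are illustrative:

import os

import torch
import torch.nn as nn


def maybe_save_checkpoint(step, interval, checkpoint_dir, model, optimizer):
    """Write a checkpoint only every `interval` steps (the hook's save-by-interval pattern)."""
    if step % interval != 0:
        return
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"step_{step}.pt")
    torch.save(
        {"step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )


model = nn.Linear(4, 2)  # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(1, 11):
    # ... one training step would run here ...
    maybe_save_checkpoint(step, interval=5, checkpoint_dir="./checkpoints", model=model, optimizer=optimizer)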