From 80e37eec4215cd81a54a525af14a1030dd6177a1 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Thu, 14 Apr 2022 21:03:24 +0800 Subject: [PATCH] fix the ckpt bugs when using DDP (#769) --- colossalai/trainer/hooks/_checkpoint_hook.py | 31 ++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py index 92a2118c2..d260ddcbf 100644 --- a/colossalai/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - +import torch from colossalai.logging import get_dist_logger from colossalai.registry import HOOKS @@ -15,7 +15,12 @@ class SaveCheckpointHook(BaseHook): Args: interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1. + if save_by_iter is True, this arg refers to the number of iters between saving. checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None. + model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing, + 'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some + unexpected bugs, especially when using **DDP**. + save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False. priority (int, optional): Priority in the printing, hooks with small priority will be printed in front defaults to 10. If different hooks share same priority, the order of printing would depend on the hooks order in the hook list. @@ -24,10 +29,14 @@ class SaveCheckpointHook(BaseHook): def __init__(self, interval: int = 1, checkpoint_dir: str = None, + model: torch.nn.Module = None, + save_by_iter: bool = False, priority: int = 10): super().__init__(priority=priority) self.interval = interval self.checkpoint_dir = checkpoint_dir + self.model = model + self.save_by_iter = save_by_iter self.logger = get_dist_logger() # get lr scheduler from the LRSchedulerHook before train @@ -39,6 +48,24 @@ class SaveCheckpointHook(BaseHook): if isinstance(hook, LRSchedulerHook): self._lr_scheduler = hook.lr_scheduler break + self.model = self.model if self.model is not None else trainer.engine.model + + + def after_train_iter(self, trainer, output, label, loss): + """Saves the model after a training iter. + """ + # save by interval + if self.save_by_iter and trainer.cur_step % self.interval == 0: + save_checkpoint(self.checkpoint_dir, + trainer.cur_epoch, + self.model, + trainer.engine.optimizer, + self._lr_scheduler) + self.logger.info( + f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}', ranks=[0]) + else: + pass + def after_train_epoch(self, trainer): """Saves the model after a training epoch. @@ -47,7 +74,7 @@ class SaveCheckpointHook(BaseHook): if trainer.cur_epoch % self.interval == 0: save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, - trainer.engine.model, + self.model, trainer.engine.optimizer, self._lr_scheduler) self.logger.info(