Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
Frank Lee
2021-11-18 19:45:06 +08:00
committed by GitHub
parent 2b05de4c64
commit 3defa32aee
80 changed files with 2194 additions and 1584 deletions

View File

@@ -1,5 +1,5 @@
from ._trainer import Trainer
from .hooks import *
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D']
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate']

View File

@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from typing import Union, List
import torch
@@ -10,12 +9,11 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from colossalai.builder import build_hooks
from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
from colossalai.context import Config
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
from colossalai.nn.data import DataParallelSampler
from colossalai.utils import MultiTimer
from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
class Trainer:
@@ -30,43 +28,31 @@ class Trainer:
:type hoooks_cfg: Config, optional
:type verbose: bool, optional
"""
def __init__(self,
engine: Engine,
hooks_cfg: Optional[Config] = None,
verbose: bool = False):
verbose: bool = False,
timer: MultiTimer = None):
# training-ralated params
self._engine = engine
self._max_epochs = float('inf')
self._max_steps = float('inf')
self._max_epochs = 0
self._cur_epoch = 0
self._max_steps = 0
self._cur_step = 0
# data-related params
self._train_dataloader = None
self._test_dataloader = None
self._steps_per_epoch = 0
# misc params
self._display_progress = False
self._logger = get_global_dist_logger()
self._verbose = verbose
# hooks can store states in this dict, and could be consumed by other hooks
self.states = {}
self.states = dict()
# build hooks
self.hooks = list()
if hooks_cfg is not None:
for cfg in hooks_cfg:
hook = build_hooks(cfg, self)
self.hooks.append(hook)
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
# timer
self._timer = get_global_multitimer()
# multi-timer for time benchmarking
self._timer = timer
@property
def cur_epoch(self):
@@ -74,13 +60,65 @@ class Trainer:
"""
return self._cur_epoch
@cur_epoch.setter
def cur_epoch(self, epoch: int):
"""Set how many epochs have been processed.
"""
# allow setter for training resumption
self._cur_epoch = epoch
@property
def cur_step(self):
"""Returns how many iteration steps have been processed.
"""
return self._cur_step
def call_hooks(self, func, output=None):
@property
def max_epochs(self):
return self._max_epochs
@property
def max_steps(self):
return self._max_steps
@property
def steps_per_epoch(self):
return self._steps_per_epoch
@property
def engine(self):
return self._engine
@engine.setter
def engine(self, engine_: Engine):
self._engine = engine_
def _set_current_step(self, epoch: int):
"""Sets current step number.
:param epoch: Step number to be set
:type epoch: int
"""
self._cur_step = epoch * self._steps_per_epoch
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
"""Call timer funciton with a given timer name.
:param action: Function to be called on timer
:type action: str
:param item: Name of the timer
:type item: str
"""
if self._timer is not None:
getattr(self._timer, action)(item, *args, **kwargs)
def _reset_states(self) -> None:
"""Clear trainer states
"""
self.states = dict()
def _call_hooks(self, func, output=None):
"""Calls specific hooks in the current time point.
:param func: A string represents the time point
@@ -95,161 +133,186 @@ class Trainer:
else:
getattr(hook, func)(*output)
def exceed_max_step(self):
"""Checks whether the trainer exceeds the maximum number of runnning iterations.
@staticmethod
def _should_display_progress(display_progress: bool):
""" Only display progress on DP rank 0, TP rank 0 and PP last rank
"""
return self._cur_step >= self._max_steps
return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
def set_epoch(self, epoch):
"""Sets current epoch number.
:param epoch: Epoch number to be set
:type epoch: int
"""
self._cur_epoch = epoch
def _recover_steps(self):
step = self.cur_step * self._engine.schedule.num_steps
self._cur_step = step
def _set_display_progress(self, display_progress: bool):
self._display_progress = display_progress and is_dp_rank_0(
) and is_tp_rank_0() and is_no_pp_or_last_stage()
def _train_epoch(self, epoch: int = None):
def _train_epoch(self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False):
# set sampler epoch
if epoch is not None and \
hasattr(self._engine.train_dataloader, 'sampler') and \
isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
self._engine.train_dataloader.sampler.set_epoch(epoch)
hasattr(train_dataloader, 'sampler') and \
isinstance(train_dataloader.sampler, DataParallelSampler):
train_dataloader.sampler.set_epoch(epoch)
# set training state
self._engine.train()
progress = range(self._engine.schedule.num_steps)
if self._display_progress:
data_iter = iter(train_dataloader)
progress = range(self._steps_per_epoch)
if display_progress:
if epoch is None:
progress = tqdm(progress, desc='[Train]')
else:
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
# train 1 epoch
self.call_hooks('before_train_epoch')
self._timer.start('train-epoch')
for _ in progress:
self._call_hooks('before_train_epoch')
self._call_timer(action='start', item='train-epoch')
for i in progress:
self._call_hooks('before_train_iter')
self._call_timer(action='start', item='train-step')
if i == self._steps_per_epoch - 1:
is_last_iteration = True
else:
is_last_iteration = False
# run 1 training step
logits, label, loss = self._engine.step(data_iter, is_last_iteration)
self._call_timer(action='stop', item='train-step', keep_in_history=True)
self._call_hooks('after_train_iter', output=(logits, label, loss))
self._cur_step += 1
self.call_hooks('before_train_iter')
self._timer.start('train-step')
logits, label, loss = self._engine.step()
self._timer.stop('train-step', keep_in_history=True)
self.call_hooks('after_train_iter', output=(logits, label, loss))
if self.exceed_max_step():
# stop when max iter is reached
# stop when max iter is reached
if self._exceed_max_step():
break
self._timer.stop('train-epoch', keep_in_history=True)
self.call_hooks('after_train_epoch')
self._timer.reset('train-step')
self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
self._call_hooks('after_train_epoch')
self._call_timer(action='reset', item='train-step')
def _eval(self,
test_dataloader: DataLoader,
epoch: int = None,
return_loss: bool = True):
display_progress: bool = False):
# switch engine status
self._engine.eval()
self.call_hooks('before_test')
data_iter = iter(test_dataloader)
num_steps = len(test_dataloader)
self._call_hooks('before_test')
with torch.no_grad():
# prepare progress bar
progress = range(self._engine.schedule.num_steps)
if self._display_progress:
progress = range(num_steps)
if display_progress:
desc = 'Evaluation'
if epoch is not None:
desc = '[Epoch %d val]' % epoch
progress = tqdm(progress, desc=desc)
self.call_hooks('before_test_epoch')
self._timer.start('test-epoch')
self._call_hooks('before_test_epoch')
self._call_timer(action='start', item='test-epoch')
for _ in progress:
self.call_hooks('before_test_iter')
self._timer.start('test-step')
logits, label, loss = self._engine.step(
return_loss=return_loss)
self._timer.stop('test-step', keep_in_history=True)
self.call_hooks('after_test_iter',
output=(logits, label, loss))
self._timer.stop('test-epoch', keep_in_history=True)
self.call_hooks('after_test_epoch')
self.call_hooks('after_test')
self._timer.reset('test-step')
self._timer.reset('test-epoch')
self._call_hooks('before_test_iter')
self._call_timer(action='start', item='test-step')
logits, label, loss = self._engine.step(data_iter, return_loss=True)
self._call_timer(action='stop', item='test-step', keep_in_history=True)
self._call_hooks('after_test_iter',
output=(logits, label, loss))
self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
self._call_hooks('after_test_epoch')
self._call_hooks('after_test')
self._call_timer(action='reset', item='test-step')
self._call_timer(action='reset', item='test-epoch')
def _exceed_max_step(self):
return self._max_steps is not None and self._cur_step > self._max_steps
def fit(self,
train_dataloader: DataLoader,
test_dataloader: DataLoader = None,
max_epochs: int = None,
epochs: int,
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
display_progress: bool = False):
hooks_cfg: dict = None,
display_progress: bool = False,
):
"""Trains the model to fit training data.
:param train_dataloader: DataLoader in training
:param test_dataloader: DataLoader in testing
:param max_epochs: Maximum number of epoches
:param epochs: Maximum number of epoches
:param max_steps: Maximum number of running iterations
:param test_dataloader: DataLoader in testing
:param test_interval: Interval of testing
:param hooks_cfg: A list of hook configuration
:param display_progress: If True, the training progress will be printed
:type train_dataloader: DataLoader
:type test_dataloader: DataLoader
:type max_epochs: int
:type epochs: int
:type max_steps: int
:type test_dataloader: DataLoader
:type test_interval: int
:type hooks_cfg: dict
:type display_progress: bool
:type gradient_accumulation: int
"""
# prepare dataloaders
self._train_dataloader = train_dataloader
self._engine.set_dataloader(self._train_dataloader, train=True)
self._engine.train()
# set epochs and steps, consider gradient accumulation
self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
self._max_steps = max_steps
self._max_epochs = epochs
# check if testing is required
should_test = False
if test_dataloader is not None:
self._test_dataloader = test_dataloader
self._engine.set_dataloader(self._test_dataloader, train=False)
should_test = True
# decide the
if max_epochs is not None:
self._max_epochs = max_epochs
if max_steps is not None:
self._max_steps = max_steps
self._set_display_progress(display_progress)
display_progress = self._should_display_progress(display_progress)
# reset hooks
self._reset_states()
self.hooks = list()
# build hooks
if hooks_cfg is not None:
for cfg in hooks_cfg:
hook = build_hooks(cfg, self)
self.hooks.append(hook)
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
self._logger.info("Lower value means higher priority for calling hook function")
# start train
self.call_hooks('before_train')
self._engine.train()
self._call_hooks('before_train')
# recover step value if resuming training
if self.cur_epoch != 0:
self._recover_steps()
last_epoch = self._cur_epoch
if self.cur_epoch != 0:
self._set_current_step(last_epoch)
for epoch in range(last_epoch, self._max_epochs):
self._cur_epoch += 1
for epoch in range(last_epoch, epochs):
# train for one epoch
self._train_epoch(epoch)
self._train_epoch(
train_dataloader=train_dataloader,
epoch=epoch,
display_progress=display_progress
)
# start eval
if should_test and epoch % test_interval == 0:
self._eval(epoch, return_loss=True)
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
epoch=epoch,
)
self._cur_epoch += 1
# check for termination
if self.exceed_max_step():
if self._exceed_max_step():
self._logger.info(
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
f"Max number of steps {max_steps} has been reached, training is stopped automatically")
break
self.call_hooks('after_train')
self._timer.reset('train-epoch')
self._call_hooks('after_train')
self._call_timer('reset', 'train-epoch')
def evaluate(self,
test_dataloader: DataLoader,
@@ -261,15 +324,13 @@ class Trainer:
:type test_dataloader: DataLoader
:type display_progress: bool, optional
"""
# set dataloader
self._test_dataloader = test_dataloader
self._engine.set_dataloader(self._test_dataloader, train=True)
# set
self._set_display_progress(display_progress)
# set display
display_progress = self._should_display_progress(display_progress)
# eval
self._eval(return_loss=True)
self._eval(test_dataloader=test_dataloader,
display_progress=display_progress,
)
def predict(self, data: Union[Tensor, List[Tensor]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
@@ -289,45 +350,6 @@ class Trainer:
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
simple_dataloader = [(data, None)]
self._engine.set_dataloader(simple_dataloader)
output, _, _ = self._engine.step(return_loss=False)
data_iter = iter(simple_dataloader)
output, _, _ = self._engine.step(data_iter, return_loss=False)
return output
def save(self, path: str, suffix: str = ''):
"""Saves the model to a file.
:param path: Relative path of the file
:param suffix: Suffix of the file
:type path: str
:type suffix: str, optional
"""
save_path = get_checkpoint_path(path,
self._cur_epoch,
suffix=suffix)
save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
self._engine.get_optimizer(),
self._engine.get_lr_scheduler())
def load(self,
path: str,
finetune: bool = False,
strict: bool = False):
"""Loads parameters to the model from a file.
:param path: Relative path of the file
:param finetune: Whether allows to load a part of the model
:param strict: Whether loads a model that has the same shape of parameters
:type path: str
:type finetune: bool, optional
:type strict: bool, optional
"""
last_epoch, _ = load_checkpoint(path,
self._engine.get_model(),
self._engine.get_optimizer(),
self._engine.get_lr_scheduler(),
finetune=finetune,
strict=strict)
if finetune:
self.set_epoch(0)
else:
self.set_epoch(last_epoch)

View File

@@ -2,10 +2,12 @@ from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
from ._lr_scheduler_hook import LRSchedulerHook
__all__ = [
'BaseHook', 'MetricHook',
'LoadCheckpointHook', 'SaveCheckpointHook',
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
'LRSchedulerHook'
]

View File

@@ -3,13 +3,13 @@
import os.path as osp
import torch.distributed as dist
from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
from colossalai.registry import HOOKS
from colossalai.trainer.hooks import BaseHook
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import BaseHook
from colossalai.utils import is_dp_rank_0
from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
from ._lr_scheduler_hook import LRSchedulerHook
@HOOKS.register_module
@@ -33,7 +33,7 @@ class SaveCheckpointHook(BaseHook):
interval: int = 1,
checkpoint_dir: str = None,
suffix: str = '',
priority: int = 0):
priority: int = 10):
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
@@ -41,6 +41,16 @@ class SaveCheckpointHook(BaseHook):
self.checkpoint_dir = checkpoint_dir
self.suffix = suffix
# get lr scheduler from the LRSchedulerHook before train
self._lr_scheduler = None
def before_train(self):
# check if lr scheduler is present in LRSchedulerHook
for hook in self.trainer.hooks:
if isinstance(hook, LRSchedulerHook):
self._lr_scheduler = hook.lr_scheduler
break
def after_train_epoch(self):
"""Saves the model after a training epoch.
"""
@@ -48,14 +58,18 @@ class SaveCheckpointHook(BaseHook):
if self.trainer.cur_epoch % self.interval == 0:
# only gpus with data parallel rank equals to 0 write to the disk
if is_dp_rank_0():
self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
save_path = get_checkpoint_path(self.checkpoint_dir,
self.trainer.cur_epoch,
suffix=self.suffix)
save_checkpoint(save_path,
self.trainer.cur_epoch,
self.trainer.engine.model,
self.trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
# wait until everyone is done
if dist.is_initialized():
dist.barrier()
@HOOKS.register_module
class LoadCheckpointHook(BaseHook):
@@ -81,30 +95,46 @@ class LoadCheckpointHook(BaseHook):
epoch: int = -1,
finetune: bool = False,
strict: bool = False,
priority: int = 10) -> None:
suffix: str = '',
priority: int = 0) -> None:
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
self.epoch = epoch
self.checkpoint_dir = checkpoint_dir
self.finetune = finetune
self.suffix = suffix
self.strict = strict
super().__init__(trainer=trainer, priority=priority)
def before_train(self):
"""Loads parameters to the model before training.
"""
# check if lr scheduler is present in LRSchedulerHook
lr_scheduler = None
for hook in self.trainer.hooks:
if isinstance(hook, LRSchedulerHook):
lr_scheduler = hook.lr_scheduler
break
# use latest checkpoint if epoch = -1
if self.epoch == -1:
path = get_latest_checkpoint_path(self.checkpoint_dir)
path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
else:
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)
if osp.exists(path):
self.trainer.load(
path, finetune=self.finetune, strict=self.strict)
last_epoch, _ = load_checkpoint(path,
self.trainer.engine.model,
self.trainer.engine.optimizer,
lr_scheduler,
finetune=self.finetune,
strict=self.strict)
if self.finetune:
self.trainer.cur_epoch = 0
else:
self.trainer.cur_epoch = last_epoch
self.logger.info(
f'loaded checkpoint from {path}')
else:
raise FileNotFoundError(f'checkpoint is not found at {path}')
# Some utilities want to load a checkpoint without distributed being initialized
if dist.is_initialized():
dist.barrier()

View File

@@ -5,7 +5,7 @@ import os
import os.path as osp
import torch
from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@@ -13,7 +13,7 @@ from colossalai.registry import HOOKS
from colossalai.trainer._trainer import Trainer
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage
from ._metric_hook import MetricHook
from ._base_hook import BaseHook
def _format_number(val):
@@ -24,7 +24,7 @@ def _format_number(val):
return val
class EpochIntervalHook(MetricHook):
class EpochIntervalHook(BaseHook):
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
super().__init__(trainer, priority)
self._interval = interval
@@ -45,7 +45,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
@@ -74,7 +74,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
@HOOKS.register_module
class TensorboardHook(MetricHook):
class TensorboardHook(BaseHook):
"""Specialized Hook to record the metric to Tensorboard.
:param trainer: Trainer attached with current hook
@@ -85,59 +85,71 @@ class TensorboardHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
def __init__(self,
trainer: Trainer,
log_dir: str,
dp_rank_0_only: bool = True,
tp_rank_0_only: bool = True,
priority: int = 10,
) -> None:
super().__init__(trainer=trainer, priority=priority)
self._is_rank_to_log = is_no_pp_or_last_stage()
if self._is_rank_to_log:
# create log dir
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
os.makedirs(log_dir, exist_ok=True)
# determine the ranks to generate tensorboard logs
self._is_valid_rank_to_log = is_no_pp_or_last_stage()
if dp_rank_0_only:
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
if tp_rank_0_only:
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
if self._is_valid_rank_to_log:
# create workspace on only one rank
if gpc.is_initialized(ParallelMode.GLOBAL):
rank = gpc.get_global_rank()
else:
rank = 0
log_dir = osp.join(log_dir, f'rank_{rank}')
# create workspace
if not osp.exists(log_dir):
os.makedirs(log_dir)
log_dir = osp.join(log_dir, f'rank_{rank}')
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
def after_train_iter(self, *args):
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
def _log_by_iter(self, mode: str):
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_rank_to_log:
self.writer.add_scalar(
f'{metric_name}/train', val, self.trainer.cur_step)
def after_test_iter(self, *args):
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/test', val,
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
self.trainer.cur_step)
def after_test_epoch(self):
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
def _log_by_epoch(self, mode: str):
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/test', val,
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
self.trainer.cur_step)
def after_test_iter(self, *args):
self._log_by_iter(mode='test')
def after_test_epoch(self):
self._log_by_epoch(mode='test')
def after_train_iter(self, *args):
self._log_by_iter(mode='train')
def after_train_epoch(self):
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/train', val,
self.trainer.cur_step)
self._log_by_epoch(mode='train')
@HOOKS.register_module
@@ -157,7 +169,7 @@ class LogTimingByEpochHook(EpochIntervalHook):
def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 1,
priority: int = 10,
log_eval: bool = True
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
@@ -217,7 +229,7 @@ class LogMemoryByEpochHook(EpochIntervalHook):
def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 1,
priority: int = 10,
log_eval: bool = True
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)

View File

@@ -0,0 +1,58 @@
from torch import Tensor
from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
from .._trainer import Trainer
from ..metric import LearningRate
@HOOKS.register_module
class LRSchedulerHook(MetricHook):
"""Build LR scheduler
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param lr_scheduler_cfg: The config of LR scheduler
:type lr_scheduler_cfg: dict
:param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`.
:type by_epoch: bool
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
"""
def __init__(self,
trainer: Trainer,
lr_scheduler_cfg: dict,
by_epoch: bool = True,
store_lr_in_state: bool = True,
priority: int = 1,
):
super().__init__(trainer=trainer, priority=priority)
self.by_epoch = by_epoch
if by_epoch:
total_steps = trainer.max_epochs
else:
total_steps = trainer.max_epochs * trainer.steps_per_epoch
if trainer.max_steps is not None:
total_steps = min(total_steps, trainer.max_steps)
lr_scheduler_cfg['total_steps'] = total_steps
self.lr_scheduler = build_lr_scheduler(
lr_scheduler_cfg, trainer.engine.optimizer)
if store_lr_in_state:
self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
initial_lr=self.lr_scheduler.get_lr()[0])
def after_train_epoch(self):
if self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])

View File

@@ -21,9 +21,12 @@ class MetricHook(BaseHook):
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int):
def __init__(self,
trainer: Trainer,
priority: int,
):
super().__init__(trainer, priority)
self._is_stage_to_log = is_no_pp_or_last_stage()
self._is_stage_to_compute = is_no_pp_or_last_stage()
self._check_metric_states_initialization()
def _check_metric_states_initialization(self):
@@ -41,33 +44,34 @@ class LossHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Loss(epoch_only=False)
if self._is_stage_to_compute:
self.train_loss = Loss(epoch_only=False)
self.test_loss = Loss(epoch_only=True)
# register the metric calculator
self.trainer.states['metrics']['train'][
self.metric.__class__.__name__] = self.metric
self.train_loss.__class__.__name__] = self.train_loss
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
self.test_loss.__class__.__name__] = self.test_loss
def before_train_epoch(self):
if self._is_stage_to_log:
self.metric.reset()
if self._is_stage_to_compute:
self.train_loss.reset()
def after_train_iter(self, logits, label, loss):
if self._is_stage_to_log:
self.metric.update(loss)
if self._is_stage_to_compute:
self.train_loss.update(loss)
def before_test_epoch(self):
if self._is_stage_to_log:
self.metric.reset()
if self._is_stage_to_compute:
self.test_loss.reset()
def after_test_iter(self, logits, label, loss):
if self._is_stage_to_log:
self.metric.update(loss)
if self._is_stage_to_compute:
self.test_loss.update(loss)
@HOOKS.register_module
@@ -81,10 +85,10 @@ class Accuracy2DHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy2D(epoch_only=True)
# register the metric
@@ -92,20 +96,20 @@ class Accuracy2DHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2p5DHook(MetricHook):
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy2p5D(epoch_only=True)
# register the metric
@@ -113,11 +117,11 @@ class Accuracy2p5DHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)
@@ -138,7 +142,7 @@ class Accuracy3DHook(MetricHook):
priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy3D(epoch_only=True,
input_parallel_mode=input_parallel_mode,
weight_parallel_mode=weight_parallel_mode)
@@ -148,11 +152,11 @@ class Accuracy3DHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)
@@ -166,10 +170,10 @@ class AccuracyHook(MetricHook):
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int = 10):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric = Accuracy(epoch_only=True)
# register the metric
@@ -177,9 +181,9 @@ class AccuracyHook(MetricHook):
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
if self._is_stage_to_compute:
self.metric.update(logits, label)

View File

@@ -126,6 +126,33 @@ class Loss(Metric):
return a < b
class LearningRate(Metric):
"""A metric collector for learning rate.
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = 0.
def reset(self) -> None:
pass
def update(self, lr) -> None:
self.lr = lr
def get_last_step_value(self):
return self.lr
def get_accumulated_value(self):
return self.lr
def is_better(a, b) -> bool:
pass
class Accuracy(Metric):
"""A metric collector for accuracy. It only works for classification
tasks.