[legacy] move trainer to legacy (#4545)

* [legacy] move trainer to legacy

* [doc] update docs related to trainer

* [test] ignore legacy test
This commit is contained in:
Hongxin Liu
2023-08-31 13:51:28 +08:00
parent 807e01a4ba
commit 89fe027787
32 changed files with 63 additions and 153 deletions

View File

@@ -0,0 +1,3 @@
from ._trainer import Trainer
__all__ = ['Trainer']

View File

@@ -0,0 +1,408 @@
from typing import Any, List, Union
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from colossalai.engine import Engine
from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import DistributedLogger
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
class Trainer:
r"""This is a class tending for easy deployments of users' training and evaluation instead of
writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is
called `Trainer`.
Args:
engine (:class:`Engine`): Engine responsible for the process function.
timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
Examples:
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
>>> model = ...
>>> criterion = ...
>>> optimizer = ...
>>> train_dataloader = ...
>>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler
>>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion)
>>> # Beginning training progress
>>> timer = ...
>>> logger = ...
>>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
>>> # add hooks you would like to use here.
>>> hook_list = []
>>> trainer.fit(
>>> train_dataloader=train_dataloader,
>>> epochs=gpc.config.NUM_EPOCHS,
>>> test_interval=1,
>>> hooks=hook_list,
>>> display_progress=True,
>>> return_output_label=False
>>> )
More examples and details could be found in
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
"""
def __init__(
self,
engine: Engine,
timer: MultiTimer = None,
logger: DistributedLogger = None,
):
# training-related params
self._engine = engine
self._max_epochs = 0
self._cur_epoch = 0
self._max_steps = 0
self._cur_step = 0
self._steps_per_epoch = 0
# misc params
self._logger = logger
self._verbose = logger is not None
# hooks can store states in this dict, and could be consumed by other hooks
self.states = dict()
# build hooks
self.hooks = list()
# multi-timer for time benchmarking
self._timer = timer
@property
def cur_epoch(self):
"""Returns the index of the current epoch."""
return self._cur_epoch
@cur_epoch.setter
def cur_epoch(self, epoch: int):
"""Set how many epochs have been processed."""
# allow setter for training resumption
self._cur_epoch = epoch
@property
def cur_step(self):
"""Returns how many iteration steps have been processed."""
return self._cur_step
@property
def max_epochs(self):
return self._max_epochs
@property
def max_steps(self):
return self._max_steps
@property
def steps_per_epoch(self):
return self._steps_per_epoch
@property
def engine(self):
return self._engine
def _set_current_step(self, epoch: int):
"""Sets current step number.
Args:
epoch (int): Step number to be set.
"""
self._cur_step = epoch * self._steps_per_epoch
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
"""Call timer function with a given timer name.
Args:
action (str): Function to be called on timer.
item (str): Name of the timer.
args (list): args used for action function.
kwargs (dict): kwargs used for action function.
"""
if self._timer is not None:
getattr(self._timer, action)(item, *args, **kwargs)
def _reset_states(self) -> None:
"""Clear trainer states"""
self.states = dict()
def _call_hooks(self, func, output=None):
"""Calls specific hooks in the current time point.
Args:
func (str): A string represents the time point.
output (Any, optional): Output of the model after running an iteration or None in any other time points.
"""
# Only after iter hook will receive output
for hook in self.hooks:
if output is None:
getattr(hook, func)(self)
else:
getattr(hook, func)(self, *output)
@staticmethod
def _should_display_progress(display_progress: bool):
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
def _train_epoch(
self,
train_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
):
# set training state
self._engine.train()
data_iter = iter(train_dataloader)
progress = range(self._steps_per_epoch)
if display_progress:
if epoch is None:
progress = tqdm(progress, desc="[Train]")
else:
progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]")
self._call_hooks("before_train_epoch")
self._call_timer(action="start", item="Train-epoch")
for i in progress:
self._call_hooks("before_train_iter")
self._call_timer(action="start", item="Train-step")
# run 1 training step
self.engine.zero_grad()
logits, label, loss = self.engine.execute_schedule(
data_iter,
forward_only=False,
return_loss=True,
return_output_label=return_output_label,
)
self.engine.step()
self._call_timer(action="stop", item="Train-step", keep_in_history=True)
self._call_hooks("after_train_iter", output=(logits, label, loss))
self._cur_step += 1
if display_progress:
if "step_metrics" in self.states:
progress.set_postfix(**self.states["step_metrics"])
# stop when max iter is reached
if self._exceed_max_step():
break
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
self._call_hooks("after_train_epoch")
self._call_timer(action="reset", item="Train-epoch")
def _eval(
self,
test_dataloader: DataLoader,
epoch: int = None,
display_progress: bool = False,
return_output_label: bool = True,
):
# switch engine status
self._engine.eval()
data_iter = iter(test_dataloader)
num_steps = len(test_dataloader)
self._call_hooks("before_test")
# prepare progress bar
progress = range(num_steps)
if display_progress:
desc = "Evaluation"
if epoch is not None:
desc = "[Epoch %d / Test]" % epoch
progress = tqdm(progress, desc=desc)
self._call_hooks("before_test_epoch")
self._call_timer(action="start", item="Test-epoch")
with torch.no_grad():
for _ in progress:
self._call_hooks("before_test_iter")
self._call_timer(action="start", item="Test-step")
logits, label, loss = self.engine.execute_schedule(
data_iter,
forward_only=True,
return_loss=True,
return_output_label=return_output_label,
)
self._call_timer(action="stop", item="Test-step", keep_in_history=True)
self._call_hooks("after_test_iter", output=(logits, label, loss))
if display_progress:
if "step_metrics" in self.states:
progress.set_postfix(**self.states["step_metrics"])
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
self._call_hooks("after_test_epoch")
self._call_hooks("after_test")
self._call_timer(action="reset", item="Test-step")
self._call_timer(action="reset", item="Test-epoch")
def _exceed_max_step(self):
return self._max_steps is not None and self._cur_step >= self._max_steps
def fit(
self,
train_dataloader: DataLoader,
epochs: int,
max_steps: int = None,
test_dataloader: DataLoader = None,
test_interval: int = 1,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
):
r"""Trains the model to fit training data.
Args:
train_dataloader (:class:`torch.utils.data.DataLoader`): DataLoader for training.
epochs (int): Maximum number of epochs.
max_steps (int, optional): Maximum number of running iterations.
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): DataLoader for validation.
test_interval (int, optional): Interval of validation
hooks (list[BaseHook], optional): A list of hooks used in training.
display_progress (bool, optional): If True, a progress bar will be displayed.
"""
# set epochs and steps, consider gradient accumulation
self._steps_per_epoch = len(train_dataloader)
self._max_steps = max_steps
self._max_epochs = epochs
# check if testing is required
should_test = False
if test_dataloader is not None:
should_test = True
display_progress = self._should_display_progress(display_progress)
# reset hooks
self._reset_states()
if hooks is not None:
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
for hook in hooks:
assert isinstance(hook, BaseHook), \
f'expected the hook to be of type BaseHook, but got {type(hook)}'
else:
hooks = []
self.hooks = hooks
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0],
)
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
self._call_hooks("after_hook_is_attached")
self._engine.train()
self._call_hooks("before_train")
# recover step value if resuming training
last_epoch = self._cur_epoch
if self.cur_epoch != 0:
self._set_current_step(last_epoch)
for epoch in range(last_epoch, epochs):
# train for one epoch
self._train_epoch(
train_dataloader=train_dataloader,
epoch=epoch,
display_progress=display_progress,
return_output_label=return_output_label,
)
# start eval
if should_test and epoch % test_interval == 0:
self._eval(
test_dataloader=test_dataloader,
display_progress=display_progress,
epoch=epoch,
return_output_label=return_output_label,
)
self._cur_epoch += 1
# check for termination
if self._exceed_max_step():
self._logger.info(
f"Max number of steps {max_steps} has been reached, training is stopped automatically",
ranks=[0],
)
break
self._call_hooks("after_train")
self._call_timer("reset", "Train-epoch")
def evaluate(
self,
test_dataloader: DataLoader,
hooks: List[BaseHook] = None,
display_progress: bool = False,
return_output_label: bool = True,
):
"""Evaluates the model with testing data.
Args:
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
hooks (list, optional): A list of hooks used in evaluation. Defaults to None.
display_progress (bool, optional): If True, the evaluation progress will be printed. Defaults to False.
return_output_label (bool, optional): If True, the output of model and the label
will be returned. Defaults to True.
"""
# set display
display_progress = self._should_display_progress(display_progress)
# reset hooks
self._reset_states()
if hooks is not None:
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
else:
hooks = []
self.hooks = hooks
self.hooks.sort(key=lambda hook: hook.priority)
if self._verbose:
for hook in self.hooks:
self._logger.info(
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
ranks=[0],
)
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
self._call_hooks("after_hook_is_attached")
# eval
self._eval(
test_dataloader=test_dataloader,
display_progress=display_progress,
return_output_label=return_output_label,
)
def predict(self, data: Union[Any, List[Any]]):
"""Uses trained model to make a prediction for a tensor or a tensor list.
Args:
data (Union[:class:`torch.tensor`, List[:class:`torch.tensor`]]): Data as the input.
Returns:
:class:`torch.tensor`: The output of model as the prediction
"""
# predict without labels
self._engine.eval()
# prepare a list of (data, label) to make it iterable
# for compatibility with schedule
simple_dataloader = [(data, None)]
data_iter = iter(simple_dataloader)
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
return output

View File

@@ -0,0 +1,17 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook
from ._log_hook import (
LogMemoryByEpochHook,
LogMetricByEpochHook,
LogMetricByStepHook,
LogTimingByEpochHook,
TensorboardHook,
)
from ._lr_scheduler_hook import LRSchedulerHook
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
__all__ = [
'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook',
'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook',
'SaveCheckpointHook'
]

View File

@@ -0,0 +1,106 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch import Tensor
class BaseHook(ABC):
"""This class allows users to add desired actions in specific time points
during training or evaluation.
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
"""
def __init__(self, priority: int) -> None:
self.priority = priority
def after_hook_is_attached(self, trainer):
"""Actions after hooks are attached to trainer.
"""
pass
def before_train(self, trainer):
"""Actions before training.
"""
pass
def after_train(self, trainer):
"""Actions after training.
"""
pass
def before_train_iter(self, trainer):
"""Actions before running a training iteration.
"""
pass
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
Args:
trainer (:class:`Trainer`): Trainer which is using this hook.
output (:class:`torch.Tensor`): Output of the model.
label (:class:`torch.Tensor`): Labels of the input data.
loss (:class:`torch.Tensor`): Loss between the output and input data.
"""
pass
def before_train_epoch(self, trainer):
"""Actions before starting a training epoch.
"""
pass
def after_train_epoch(self, trainer):
"""Actions after finishing a training epoch.
"""
pass
def before_test(self, trainer):
"""Actions before evaluation.
"""
pass
def after_test(self, trainer):
"""Actions after evaluation.
"""
pass
def before_test_epoch(self, trainer):
"""Actions before starting a testing epoch.
"""
pass
def after_test_epoch(self, trainer):
"""Actions after finishing a testing epoch.
"""
pass
def before_test_iter(self, trainer):
"""Actions before running a testing iteration.
"""
pass
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
Args:
trainer (:class:`Trainer`): Trainer which is using this hook
output (:class:`torch.Tensor`): Output of the model
label (:class:`torch.Tensor`): Labels of the input data
loss (:class:`torch.Tensor`): Loss between the output and input data
"""
pass
def init_runner_states(self, trainer, key, val):
"""Initializes trainer's state.
Args:
trainer (:class:`Trainer`): Trainer which is using this hook
key: Key of state to be reset
val: Value of state to be reset
"""
if key not in trainer.states:
trainer.states[key] = val

View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from colossalai.legacy.trainer.hooks import BaseHook
from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
from colossalai.utils.checkpointing import save_checkpoint
from ._lr_scheduler_hook import LRSchedulerHook
@HOOKS.register_module
class SaveCheckpointHook(BaseHook):
"""Saves the model by interval in training process.
Args:
interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1.
if save_by_iter is True, this arg refers to the number of iters between saving.
checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None.
model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing,
'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some
unexpected bugs, especially when using **DDP**.
save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(self,
interval: int = 1,
checkpoint_dir: str = None,
model: torch.nn.Module = None,
save_by_iter: bool = False,
priority: int = 10):
super().__init__(priority=priority)
self.interval = interval
self.checkpoint_dir = checkpoint_dir
self.model = model
self.save_by_iter = save_by_iter
self.logger = get_dist_logger()
# get lr scheduler from the LRSchedulerHook before train
self._lr_scheduler = None
def after_hook_is_attached(self, trainer):
# get lr scheduler if exists
for hook in trainer.hooks:
if isinstance(hook, LRSchedulerHook):
self._lr_scheduler = hook.lr_scheduler
break
self.model = self.model if self.model is not None else trainer.engine.model
def after_train_iter(self, trainer, output, label, loss):
"""Saves the model after a training iter.
"""
# save by interval
if self.save_by_iter and trainer.cur_step % self.interval == 0:
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}',
ranks=[0])
else:
pass
def after_train_epoch(self, trainer):
"""Saves the model after a training epoch.
"""
# save by interval
if trainer.cur_epoch % self.interval == 0:
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])

View File

@@ -0,0 +1,9 @@
import torch
def _format_number(val, prec=5):
if isinstance(val, float):
return f'{val:.{prec}g}'
elif torch.is_tensor(val) and torch.is_floating_point(val):
return f'{val.item():.{prec}g}'
return val

View File

@@ -0,0 +1,301 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import os.path as osp
from typing import List
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
from colossalai.registry import HOOKS
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
from ._base_hook import BaseHook
from ._commons_ import _format_number
class LogByEpochHook(BaseHook):
"""Hook to log by epoch.
Args:
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
interval (int, optional): Interval of printing log information, defaults to 1.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
defaults to 1. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(self, logger, interval: int = 1, priority: int = 1):
super().__init__(priority)
self.logger = logger
self._interval = interval
def _is_epoch_to_log(self, trainer):
return trainer.cur_epoch % self._interval == 0
@HOOKS.register_module
class LogMetricByStepHook(BaseHook):
"""Hook to log metric by step.
Args:
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
def after_train_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
if isinstance(metric_calculator, ThroughputMetric):
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
else:
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
if isinstance(metric_calculator, ThroughputMetric):
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
else:
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
@HOOKS.register_module
class LogMetricByEpochHook(LogByEpochHook):
"""Specialized hook to record the metric to log.
Args:
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
interval (int, optional): Interval of printing log information, defaults to 1.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(self, logger, interval: int = 1, priority: int = 10) -> None:
super().__init__(logger, interval, priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
def _get_str(self, trainer, mode):
msg = []
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg = ' | '.join(msg)
return msg
def after_train_epoch(self, trainer):
if self._is_epoch_to_log(trainer):
msg = self._get_str(trainer=trainer, mode='train')
if self._is_rank_to_log:
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}')
# f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
def after_test_epoch(self, trainer):
if self._is_epoch_to_log(trainer):
msg = self._get_str(trainer=trainer, mode='test')
if self._is_rank_to_log:
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
# f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
class TensorboardHook(BaseHook):
"""Specialized hook to record the metric to Tensorboard.
Args:
log_dir (str): Directory of log.
ranks (list): Ranks of processors.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer,
defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(
self,
log_dir: str,
ranks: List = None,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
priority: int = 10,
) -> None:
super().__init__(priority=priority)
from torch.utils.tensorboard import SummaryWriter
# create log dir
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
os.makedirs(log_dir, exist_ok=True)
# determine the ranks to generate tensorboard logs
self._is_valid_rank_to_log = False
if not gpc.is_initialized(parallel_mode):
self._is_valid_rank_to_log = True
else:
local_rank = gpc.get_local_rank(parallel_mode)
if ranks is None or local_rank in ranks:
self._is_valid_rank_to_log = True
# check for
if gpc.is_initialized(ParallelMode.PIPELINE) and \
not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log:
raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group")
if self._is_valid_rank_to_log:
# create workspace on only one rank
if gpc.is_initialized(parallel_mode):
rank = gpc.get_local_rank(parallel_mode)
else:
rank = 0
# create workspace
log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}')
def _log_by_iter(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
def _log_by_epoch(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
def after_test_iter(self, trainer, *args):
self._log_by_iter(trainer, mode='test')
def after_test_epoch(self, trainer):
self._log_by_epoch(trainer, mode='test')
def after_train_iter(self, trainer, *args):
self._log_by_iter(trainer, mode='train')
def after_train_epoch(self, trainer):
self._log_by_epoch(trainer, mode='train')
@HOOKS.register_module
class LogTimingByEpochHook(LogByEpochHook):
"""Specialized hook to write timing record to log.
Args:
timer (:class:`colossalai.utils.MultiTimer`): Timer for the hook.
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
interval (int, optional): Interval of printing log information, defaults to 1.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
log_eval (bool, optional): Whether writes in evaluation, defaults to True.
ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0.
"""
def __init__(self,
timer: MultiTimer,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
ignore_num_train_steps: int = 0) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._timer = timer
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
# extra handling to avoid the unstable readings of the first
# few training steps to affect the history mean time
self._ignore_num_train_steps = ignore_num_train_steps
self._is_train_step_history_trimmed = False
def _get_message(self, mode):
msg = []
for timer_name, timer in self._timer:
if timer_name.startswith(mode):
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
if timer_name == 'Train-step' and not self._is_train_step_history_trimmed:
timer._history = timer._history[self._ignore_num_train_steps:]
self._is_train_step_history_trimmed = True
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s'
)
else:
msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
msg = ' | '.join(msg)
return msg
def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
msg = self._get_message('Train')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}')
def after_test_epoch(self, trainer):
"""Writes log after finishing a testing epoch.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
msg = self._get_message('Test')
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
@HOOKS.register_module
class LogMemoryByEpochHook(LogByEpochHook):
"""Specialized Hook to write memory usage record to log.
Args:
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
interval (int, optional): Interval of printing log information, defaults to 1.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 1. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
log_eval (bool, optional): Whether writes in evaluation, defaults to True.
"""
def __init__(
self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False, # no reference
) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
def before_train(self, trainer):
"""Resets before training.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage('Before-train', self.logger)
def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger)
def after_test(self, trainer):
"""Reports after testing.
"""
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger)

View File

@@ -0,0 +1,48 @@
from torch import Tensor
from colossalai.registry import HOOKS
from ._metric_hook import LearningRateMetric, MetricHook
@HOOKS.register_module
class LRSchedulerHook(MetricHook):
r"""Build LR scheduler for trainer.
Args:
lr_scheduler (:class:`colossalai.nn.lr_scheduler`): The specific LR scheduler
in range of ``colossalai.nn.lr_scheduler``, more details about ``lr_scheduler`` could be found in
`lr_scheduler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/lr_scheduler>`_.
by_epoch (bool): If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch.
store_lr_in_state (bool, optional): If `True`, store the learning rate in each state, defaults to `True`.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 1. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(
self,
lr_scheduler,
by_epoch: bool,
store_lr_in_state: bool = True,
priority: int = 1,
):
super().__init__(priority=priority)
self.by_epoch = by_epoch
self.lr_scheduler = lr_scheduler
self.store_lr_in_state = store_lr_in_state
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
initial_lr=self.lr_scheduler.get_last_lr()[0])
def after_train_epoch(self, trainer):
if self.by_epoch:
self.lr_scheduler.step()
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])

View File

@@ -0,0 +1,438 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Callable
import torch
import torch.distributed as dist
from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
from ._base_hook import BaseHook
from ._commons_ import _format_number
class Metric(ABC):
"""A basic class of metric collectors. It collects a specific
metric during training or evaluation and would always be used with
:class:`MetricHook` to help it update its states and show the
metric. So please use corresponding hook class to make the metric
collector works.
Args:
epoch_only (bool): Whether the metric only read for the full epoch.
"""
def __init__(self, epoch_only: bool):
# is the metric only read for the full epoch
self._epoch_only = epoch_only
@property
def epoch_only(self):
"""Returns :attr:`epoch_only`.
"""
return self._epoch_only
@abstractmethod
def reset(self) -> None:
"""Resets the metric to it's initial state.
By default, this is called at the start of each epoch.
"""
pass
@abstractmethod
def update(self, *args, **kwargs) -> None:
"""Updates the metric's state using the passed batch output.
By default, this is called once for each batch.
"""
pass
@abstractmethod
def get_last_step_value(self) -> float:
"""Returns the metric value in the last iteration.
"""
pass
@abstractmethod
def get_accumulated_value(self):
"""Computes the metric based on it's accumulated state.
By default, this is called at the end of each epoch.
:return: the actual quantity of interest
:rtype: Any
"""
pass
@staticmethod
@abstractmethod
def is_better(a, b) -> bool:
"""Compares a and b, and returns whether a is better than b
:return: The result of comparison
:rtype: bool
"""
pass
class LossMetric(Metric):
"""A metric collector for loss.
Args:
epoch_only (bool): Whether the metric only read for the full epoch.
"""
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
self.accum_loss = torch.zeros(1, device=get_current_device())
self.count = 0
def reset(self) -> None:
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
"""
self.last_step_loss.zero_()
self.accum_loss.zero_()
self.count = 0
def update(self, loss) -> None:
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
It expects the output has loss.
Args:
loss (:class:`torch.tensor`): Current loss of the output.
"""
# expect output to be logits, label and loss
loss_ = loss.detach()
self.last_step_loss.copy_(loss_)
self.accum_loss.add_(loss_)
self.count += 1
def get_accumulated_value(self):
"""Returns accumulated loss.
"""
if gpc.is_initialized(ParallelMode.DATA):
dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
self.accum_loss.div_(self.count)
return self.accum_loss.item()
def get_last_step_value(self) -> float:
"""Returns :attr:`last_step_loss`.
"""
return self.last_step_loss.cpu().item()
@staticmethod
def is_better(a, b):
return a < b
class LearningRateMetric(Metric):
"""A metric collector for learning rate.
Args:
epoch_only (bool): Whether the metric only read for the full epoch.
initial_lr (float, optional): Initial learning rate, defaults to 0.0.
"""
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
super().__init__(epoch_only=epoch_only)
self.lr = initial_lr
def reset(self) -> None:
pass
def update(self, lr) -> None:
self.lr = lr
def get_last_step_value(self) -> float:
return self.lr
def get_accumulated_value(self):
return self.lr
@staticmethod
def is_better(a, b) -> bool:
pass
class AccuracyMetric(Metric):
"""A metric collector for accuracy. It only works for classification
tasks.
Args:
epoch_only (bool): Whether the metric only read for the full epoch.
accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task.
"""
def __init__(self, epoch_only: bool, accuracy_func: Callable):
super().__init__(epoch_only=epoch_only)
self.acc = accuracy_func
self.last_step_sum = torch.zeros(1, device=get_current_device())
self.last_step_correct = torch.zeros(1, device=get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_current_device())
def reset(self) -> None:
self.last_step_sum.zero_()
self.last_step_correct.zero_()
self.accumulated_sum.zero_()
self.accumulated_correct.zero_()
def update(self, logits, targets, batch_size) -> None:
"""Updates last step accuracy and accumulated accuracy with current logits
and labels. It expects the output has logits and labels.
Args:
logits (:class:`torch.tensor`): The logits output of the model.
targets (:class:`torch.tensor`): Real labels of the dataset.
batch_size (int): Batch size of the task.
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
if isinstance(targets, (list, tuple)):
targets = targets[0]
# update
correct = self.acc(logits, targets)
self.last_step_sum.fill_(batch_size)
self.last_step_correct.fill_(correct)
self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct
def get_last_step_value(self) -> float:
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
return _format_number((self.last_step_correct / self.last_step_sum).cpu().item())
def get_accumulated_value(self):
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
return (self.accumulated_correct / self.accumulated_sum).item()
@staticmethod
def is_better(a, b) -> bool:
return a > b
class MetricHook(BaseHook):
"""Specialized hook classes for :class:`Metric`.
Some help metric collectors initialize, reset and
update their states. Others are used to display and
record the metric.
Args:
priority (int): Priority in the printing, hooks with small priority will be printed in front
defaults to 1. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(
self,
priority: int,
):
super().__init__(priority)
self._is_stage_to_compute = is_no_pp_or_last_stage()
def _check_metric_states_initialization(self, trainer):
if 'metrics' not in trainer.states:
self.init_runner_states(trainer, 'metrics', dict(train={}, test={}))
@HOOKS.register_module
class LossHook(MetricHook):
"""Specialized hook class for :class:`Loss`.
Args:
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 0. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.train_loss = LossMetric(epoch_only=False)
self.test_loss = LossMetric(epoch_only=True)
# register the metric calculator
trainer.states['metrics']['train']['Loss'] = self.train_loss
trainer.states['metrics']['test']['Loss'] = self.test_loss
def before_train_epoch(self, trainer):
if self._is_stage_to_compute:
self.train_loss.reset()
def after_train_iter(self, trainer, logits, label, loss):
if self._is_stage_to_compute:
self.train_loss.update(loss)
def before_test_epoch(self, trainer):
if self._is_stage_to_compute:
self.test_loss.reset()
def after_test_iter(self, trainer, logits, label, loss):
if self._is_stage_to_compute:
self.test_loss.update(loss)
@HOOKS.register_module
class AccuracyHook(MetricHook):
"""Specialized hook class for :class:`Accuracy`.
Args:
accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 0. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
"""
def __init__(self, accuracy_func: Callable, priority: int = 0):
super().__init__(priority)
self.accuracy_func = accuracy_func
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func)
# register the metric
trainer.states['metrics']['test']['Accuracy'] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, targets, *args):
if self._is_stage_to_compute:
batch_size = trainer.engine.schedule.batch_size
self.metric.update(logits, targets, batch_size)
class ThroughputMetric(Metric):
"""Metric for :class:`Throughput`.
Args:
epoch_only (bool): Whether the metric only read for the full epoch.
"""
def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0, use_local: bool = False):
super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps
self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device())
self._tflop_per_step = tflop_per_step
self._use_local = use_local
def reset(self) -> None:
# self.cur_steps = 0
self.accumulated_num_samples.zero_()
self.accumulated_used_time.zero_()
self.last_step_num_samples.zero_()
self.last_step_used_time.zero_()
def update(self, num_samples, time) -> None:
self.cur_steps += 1
self.last_step_num_samples.fill_(num_samples)
self.last_step_used_time.fill_(time)
if self.cur_steps >= self.ignored_steps:
self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time
def get_last_step_value(self) -> float:
if self._use_local:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
return sample_per_sec
def get_last_step_info(self) -> str:
if self._use_local:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
if self._tflop_per_step > 0:
tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
else:
return f"{sample_per_sec} sample_per_sec"
def get_accumulated_value(self) -> float:
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
@staticmethod
def is_better(a, b) -> bool:
pass
@HOOKS.register_module
class ThroughputHook(MetricHook):
"""Specialized hook class for :class:`Throughput`. Hook to measure execution throughput (samples/sec).
Args:
ignored_steps (int, optional): the number of initial training steps to ignore.
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
defaults to 10. If different hooks share same priority, the order of printing would
depend on the hooks order in the hook list.
tflop_per_step(int, optional): tera floating point operations per step.
use_local (bool, optional): Whether to use local time for throughput calculation.
"""
def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False):
super().__init__(priority)
self.ignored_steps = ignored_steps
self._tflop_per_step = tflop_per_step
self._use_local = use_local
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = ThroughputMetric(epoch_only=True,
ignored_steps=self.ignored_steps,
tflop_per_step=self._tflop_per_step,
use_local=self._use_local)
# register the metric
trainer.states['metrics']['train']['Throughput'] = self.metric
trainer.states['metrics']['test']['Throughput'] = self.metric
def before_train_epoch(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_train_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.engine.schedule.batch_size,
trainer._timer.get_timer('Train-step').get_elapsed_time())
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, *args):
if self._is_stage_to_compute:
self.metric.update(trainer.engine.schedule.batch_size,
trainer._timer.get_timer('Test-step').get_elapsed_time())