mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[legacy] move trainer to legacy (#4545)
* [legacy] move trainer to legacy * [doc] update docs related to trainer * [test] ignore legacy test
This commit is contained in:
0
colossalai/legacy/__init__.py
Normal file
0
colossalai/legacy/__init__.py
Normal file
3
colossalai/legacy/trainer/__init__.py
Normal file
3
colossalai/legacy/trainer/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ._trainer import Trainer
|
||||
|
||||
__all__ = ['Trainer']
|
408
colossalai/legacy/trainer/_trainer.py
Normal file
408
colossalai/legacy/trainer/_trainer.py
Normal file
@@ -0,0 +1,408 @@
|
||||
from typing import Any, List, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.legacy.trainer.hooks import BaseHook
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0
|
||||
|
||||
|
||||
class Trainer:
|
||||
r"""This is a class tending for easy deployments of users' training and evaluation instead of
|
||||
writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is
|
||||
called `Trainer`.
|
||||
|
||||
Args:
|
||||
engine (:class:`Engine`): Engine responsible for the process function.
|
||||
timer (:class:`MultiTimer`, optional): Timer used to monitor the whole training.
|
||||
logger (:class:`colossalai.logging.DistributedLogger`, optional): Logger used to record the whole training log.
|
||||
|
||||
|
||||
Examples:
|
||||
>>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
|
||||
>>> model = ...
|
||||
>>> criterion = ...
|
||||
>>> optimizer = ...
|
||||
>>> train_dataloader = ...
|
||||
>>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler
|
||||
>>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion)
|
||||
>>> # Beginning training progress
|
||||
>>> timer = ...
|
||||
>>> logger = ...
|
||||
>>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
|
||||
>>> # add hooks you would like to use here.
|
||||
>>> hook_list = []
|
||||
>>> trainer.fit(
|
||||
>>> train_dataloader=train_dataloader,
|
||||
>>> epochs=gpc.config.NUM_EPOCHS,
|
||||
>>> test_interval=1,
|
||||
>>> hooks=hook_list,
|
||||
>>> display_progress=True,
|
||||
>>> return_output_label=False
|
||||
>>> )
|
||||
|
||||
More examples and details could be found in
|
||||
`Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_
|
||||
and `ColossalAI-Examples <https://github.com/hpcaitech/ColossalAI-Examples/tree/main>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
timer: MultiTimer = None,
|
||||
logger: DistributedLogger = None,
|
||||
):
|
||||
# training-related params
|
||||
self._engine = engine
|
||||
self._max_epochs = 0
|
||||
self._cur_epoch = 0
|
||||
self._max_steps = 0
|
||||
self._cur_step = 0
|
||||
self._steps_per_epoch = 0
|
||||
|
||||
# misc params
|
||||
self._logger = logger
|
||||
self._verbose = logger is not None
|
||||
|
||||
# hooks can store states in this dict, and could be consumed by other hooks
|
||||
self.states = dict()
|
||||
|
||||
# build hooks
|
||||
self.hooks = list()
|
||||
|
||||
# multi-timer for time benchmarking
|
||||
self._timer = timer
|
||||
|
||||
@property
|
||||
def cur_epoch(self):
|
||||
"""Returns the index of the current epoch."""
|
||||
return self._cur_epoch
|
||||
|
||||
@cur_epoch.setter
|
||||
def cur_epoch(self, epoch: int):
|
||||
"""Set how many epochs have been processed."""
|
||||
# allow setter for training resumption
|
||||
self._cur_epoch = epoch
|
||||
|
||||
@property
|
||||
def cur_step(self):
|
||||
"""Returns how many iteration steps have been processed."""
|
||||
return self._cur_step
|
||||
|
||||
@property
|
||||
def max_epochs(self):
|
||||
return self._max_epochs
|
||||
|
||||
@property
|
||||
def max_steps(self):
|
||||
return self._max_steps
|
||||
|
||||
@property
|
||||
def steps_per_epoch(self):
|
||||
return self._steps_per_epoch
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return self._engine
|
||||
|
||||
def _set_current_step(self, epoch: int):
|
||||
"""Sets current step number.
|
||||
|
||||
Args:
|
||||
epoch (int): Step number to be set.
|
||||
"""
|
||||
self._cur_step = epoch * self._steps_per_epoch
|
||||
|
||||
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
|
||||
"""Call timer function with a given timer name.
|
||||
|
||||
Args:
|
||||
action (str): Function to be called on timer.
|
||||
item (str): Name of the timer.
|
||||
args (list): args used for action function.
|
||||
kwargs (dict): kwargs used for action function.
|
||||
"""
|
||||
|
||||
if self._timer is not None:
|
||||
getattr(self._timer, action)(item, *args, **kwargs)
|
||||
|
||||
def _reset_states(self) -> None:
|
||||
"""Clear trainer states"""
|
||||
self.states = dict()
|
||||
|
||||
def _call_hooks(self, func, output=None):
|
||||
"""Calls specific hooks in the current time point.
|
||||
|
||||
Args:
|
||||
func (str): A string represents the time point.
|
||||
output (Any, optional): Output of the model after running an iteration or None in any other time points.
|
||||
"""
|
||||
# Only after iter hook will receive output
|
||||
for hook in self.hooks:
|
||||
if output is None:
|
||||
getattr(hook, func)(self)
|
||||
else:
|
||||
getattr(hook, func)(self, *output)
|
||||
|
||||
@staticmethod
|
||||
def _should_display_progress(display_progress: bool):
|
||||
"""Only display progress on DP rank 0, TP rank 0 and PP last rank"""
|
||||
return (display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage())
|
||||
|
||||
def _train_epoch(
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
# set training state
|
||||
self._engine.train()
|
||||
data_iter = iter(train_dataloader)
|
||||
progress = range(self._steps_per_epoch)
|
||||
if display_progress:
|
||||
if epoch is None:
|
||||
progress = tqdm(progress, desc="[Train]")
|
||||
else:
|
||||
progress = tqdm(progress, desc=f"[Epoch {epoch} / Train]")
|
||||
|
||||
self._call_hooks("before_train_epoch")
|
||||
self._call_timer(action="start", item="Train-epoch")
|
||||
for i in progress:
|
||||
self._call_hooks("before_train_iter")
|
||||
self._call_timer(action="start", item="Train-step")
|
||||
|
||||
# run 1 training step
|
||||
self.engine.zero_grad()
|
||||
logits, label, loss = self.engine.execute_schedule(
|
||||
data_iter,
|
||||
forward_only=False,
|
||||
return_loss=True,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
self.engine.step()
|
||||
self._call_timer(action="stop", item="Train-step", keep_in_history=True)
|
||||
self._call_hooks("after_train_iter", output=(logits, label, loss))
|
||||
|
||||
self._cur_step += 1
|
||||
|
||||
if display_progress:
|
||||
if "step_metrics" in self.states:
|
||||
progress.set_postfix(**self.states["step_metrics"])
|
||||
|
||||
# stop when max iter is reached
|
||||
if self._exceed_max_step():
|
||||
break
|
||||
|
||||
self._call_timer(action="stop", item="Train-epoch", keep_in_history=True)
|
||||
self._call_hooks("after_train_epoch")
|
||||
self._call_timer(action="reset", item="Train-epoch")
|
||||
|
||||
def _eval(
|
||||
self,
|
||||
test_dataloader: DataLoader,
|
||||
epoch: int = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
# switch engine status
|
||||
self._engine.eval()
|
||||
|
||||
data_iter = iter(test_dataloader)
|
||||
num_steps = len(test_dataloader)
|
||||
|
||||
self._call_hooks("before_test")
|
||||
# prepare progress bar
|
||||
progress = range(num_steps)
|
||||
if display_progress:
|
||||
desc = "Evaluation"
|
||||
if epoch is not None:
|
||||
desc = "[Epoch %d / Test]" % epoch
|
||||
progress = tqdm(progress, desc=desc)
|
||||
|
||||
self._call_hooks("before_test_epoch")
|
||||
self._call_timer(action="start", item="Test-epoch")
|
||||
with torch.no_grad():
|
||||
for _ in progress:
|
||||
self._call_hooks("before_test_iter")
|
||||
self._call_timer(action="start", item="Test-step")
|
||||
logits, label, loss = self.engine.execute_schedule(
|
||||
data_iter,
|
||||
forward_only=True,
|
||||
return_loss=True,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
self._call_timer(action="stop", item="Test-step", keep_in_history=True)
|
||||
self._call_hooks("after_test_iter", output=(logits, label, loss))
|
||||
|
||||
if display_progress:
|
||||
if "step_metrics" in self.states:
|
||||
progress.set_postfix(**self.states["step_metrics"])
|
||||
|
||||
self._call_timer(action="stop", item="Test-epoch", keep_in_history=True)
|
||||
self._call_hooks("after_test_epoch")
|
||||
self._call_hooks("after_test")
|
||||
self._call_timer(action="reset", item="Test-step")
|
||||
self._call_timer(action="reset", item="Test-epoch")
|
||||
|
||||
def _exceed_max_step(self):
|
||||
return self._max_steps is not None and self._cur_step >= self._max_steps
|
||||
|
||||
def fit(
|
||||
self,
|
||||
train_dataloader: DataLoader,
|
||||
epochs: int,
|
||||
max_steps: int = None,
|
||||
test_dataloader: DataLoader = None,
|
||||
test_interval: int = 1,
|
||||
hooks: List[BaseHook] = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
r"""Trains the model to fit training data.
|
||||
|
||||
Args:
|
||||
train_dataloader (:class:`torch.utils.data.DataLoader`): DataLoader for training.
|
||||
epochs (int): Maximum number of epochs.
|
||||
max_steps (int, optional): Maximum number of running iterations.
|
||||
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): DataLoader for validation.
|
||||
test_interval (int, optional): Interval of validation
|
||||
hooks (list[BaseHook], optional): A list of hooks used in training.
|
||||
display_progress (bool, optional): If True, a progress bar will be displayed.
|
||||
"""
|
||||
|
||||
# set epochs and steps, consider gradient accumulation
|
||||
self._steps_per_epoch = len(train_dataloader)
|
||||
self._max_steps = max_steps
|
||||
self._max_epochs = epochs
|
||||
|
||||
# check if testing is required
|
||||
should_test = False
|
||||
if test_dataloader is not None:
|
||||
should_test = True
|
||||
|
||||
display_progress = self._should_display_progress(display_progress)
|
||||
|
||||
# reset hooks
|
||||
self._reset_states()
|
||||
if hooks is not None:
|
||||
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
|
||||
for hook in hooks:
|
||||
assert isinstance(hook, BaseHook), \
|
||||
f'expected the hook to be of type BaseHook, but got {type(hook)}'
|
||||
else:
|
||||
hooks = []
|
||||
self.hooks = hooks
|
||||
self.hooks.sort(key=lambda hook: hook.priority)
|
||||
if self._verbose:
|
||||
for hook in self.hooks:
|
||||
self._logger.info(
|
||||
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
||||
ranks=[0],
|
||||
)
|
||||
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
||||
self._call_hooks("after_hook_is_attached")
|
||||
|
||||
self._engine.train()
|
||||
self._call_hooks("before_train")
|
||||
|
||||
# recover step value if resuming training
|
||||
last_epoch = self._cur_epoch
|
||||
if self.cur_epoch != 0:
|
||||
self._set_current_step(last_epoch)
|
||||
|
||||
for epoch in range(last_epoch, epochs):
|
||||
# train for one epoch
|
||||
self._train_epoch(
|
||||
train_dataloader=train_dataloader,
|
||||
epoch=epoch,
|
||||
display_progress=display_progress,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
|
||||
# start eval
|
||||
if should_test and epoch % test_interval == 0:
|
||||
self._eval(
|
||||
test_dataloader=test_dataloader,
|
||||
display_progress=display_progress,
|
||||
epoch=epoch,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
|
||||
self._cur_epoch += 1
|
||||
|
||||
# check for termination
|
||||
if self._exceed_max_step():
|
||||
self._logger.info(
|
||||
f"Max number of steps {max_steps} has been reached, training is stopped automatically",
|
||||
ranks=[0],
|
||||
)
|
||||
break
|
||||
self._call_hooks("after_train")
|
||||
self._call_timer("reset", "Train-epoch")
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
test_dataloader: DataLoader,
|
||||
hooks: List[BaseHook] = None,
|
||||
display_progress: bool = False,
|
||||
return_output_label: bool = True,
|
||||
):
|
||||
"""Evaluates the model with testing data.
|
||||
|
||||
Args:
|
||||
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
|
||||
hooks (list, optional): A list of hooks used in evaluation. Defaults to None.
|
||||
display_progress (bool, optional): If True, the evaluation progress will be printed. Defaults to False.
|
||||
return_output_label (bool, optional): If True, the output of model and the label
|
||||
will be returned. Defaults to True.
|
||||
"""
|
||||
# set display
|
||||
display_progress = self._should_display_progress(display_progress)
|
||||
|
||||
# reset hooks
|
||||
self._reset_states()
|
||||
if hooks is not None:
|
||||
assert isinstance(hooks, list), f"expected argument hooks be to list, but got {type(hooks)}"
|
||||
else:
|
||||
hooks = []
|
||||
self.hooks = hooks
|
||||
self.hooks.sort(key=lambda hook: hook.priority)
|
||||
if self._verbose:
|
||||
for hook in self.hooks:
|
||||
self._logger.info(
|
||||
f"Using {hook.__class__.__name__} for training, priority = {hook.priority}",
|
||||
ranks=[0],
|
||||
)
|
||||
self._logger.info("Lower value means higher priority for calling hook function", ranks=[0])
|
||||
self._call_hooks("after_hook_is_attached")
|
||||
|
||||
# eval
|
||||
self._eval(
|
||||
test_dataloader=test_dataloader,
|
||||
display_progress=display_progress,
|
||||
return_output_label=return_output_label,
|
||||
)
|
||||
|
||||
def predict(self, data: Union[Any, List[Any]]):
|
||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||
|
||||
Args:
|
||||
data (Union[:class:`torch.tensor`, List[:class:`torch.tensor`]]): Data as the input.
|
||||
|
||||
Returns:
|
||||
:class:`torch.tensor`: The output of model as the prediction
|
||||
"""
|
||||
# predict without labels
|
||||
self._engine.eval()
|
||||
|
||||
# prepare a list of (data, label) to make it iterable
|
||||
# for compatibility with schedule
|
||||
simple_dataloader = [(data, None)]
|
||||
data_iter = iter(simple_dataloader)
|
||||
output, _, _ = self.engine.execute_schedule(data_iter, forward_only=True, return_loss=False)
|
||||
return output
|
17
colossalai/legacy/trainer/hooks/__init__.py
Normal file
17
colossalai/legacy/trainer/hooks/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from ._base_hook import BaseHook
|
||||
from ._checkpoint_hook import SaveCheckpointHook
|
||||
from ._log_hook import (
|
||||
LogMemoryByEpochHook,
|
||||
LogMetricByEpochHook,
|
||||
LogMetricByStepHook,
|
||||
LogTimingByEpochHook,
|
||||
TensorboardHook,
|
||||
)
|
||||
from ._lr_scheduler_hook import LRSchedulerHook
|
||||
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
|
||||
|
||||
__all__ = [
|
||||
'BaseHook', 'MetricHook', 'LossHook', 'AccuracyHook', 'LogMetricByEpochHook', 'TensorboardHook',
|
||||
'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook', 'ThroughputHook', 'LogMetricByStepHook',
|
||||
'SaveCheckpointHook'
|
||||
]
|
106
colossalai/legacy/trainer/hooks/_base_hook.py
Normal file
106
colossalai/legacy/trainer/hooks/_base_hook.py
Normal file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class BaseHook(ABC):
|
||||
"""This class allows users to add desired actions in specific time points
|
||||
during training or evaluation.
|
||||
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int) -> None:
|
||||
self.priority = priority
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
"""Actions after hooks are attached to trainer.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train(self, trainer):
|
||||
"""Actions before training.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train(self, trainer):
|
||||
"""Actions after training.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train_iter(self, trainer):
|
||||
"""Actions before running a training iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
||||
"""Actions after running a training iteration.
|
||||
|
||||
Args:
|
||||
trainer (:class:`Trainer`): Trainer which is using this hook.
|
||||
output (:class:`torch.Tensor`): Output of the model.
|
||||
label (:class:`torch.Tensor`): Labels of the input data.
|
||||
loss (:class:`torch.Tensor`): Loss between the output and input data.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
"""Actions before starting a training epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Actions after finishing a training epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test(self, trainer):
|
||||
"""Actions before evaluation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test(self, trainer):
|
||||
"""Actions after evaluation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test_epoch(self, trainer):
|
||||
"""Actions before starting a testing epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
"""Actions after finishing a testing epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test_iter(self, trainer):
|
||||
"""Actions before running a testing iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
||||
"""Actions after running a testing iteration.
|
||||
|
||||
Args:
|
||||
trainer (:class:`Trainer`): Trainer which is using this hook
|
||||
output (:class:`torch.Tensor`): Output of the model
|
||||
label (:class:`torch.Tensor`): Labels of the input data
|
||||
loss (:class:`torch.Tensor`): Loss between the output and input data
|
||||
"""
|
||||
pass
|
||||
|
||||
def init_runner_states(self, trainer, key, val):
|
||||
"""Initializes trainer's state.
|
||||
|
||||
Args:
|
||||
trainer (:class:`Trainer`): Trainer which is using this hook
|
||||
key: Key of state to be reset
|
||||
val: Value of state to be reset
|
||||
"""
|
||||
if key not in trainer.states:
|
||||
trainer.states[key] = val
|
73
colossalai/legacy/trainer/hooks/_checkpoint_hook.py
Normal file
73
colossalai/legacy/trainer/hooks/_checkpoint_hook.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.trainer.hooks import BaseHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils.checkpointing import save_checkpoint
|
||||
|
||||
from ._lr_scheduler_hook import LRSchedulerHook
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class SaveCheckpointHook(BaseHook):
|
||||
"""Saves the model by interval in training process.
|
||||
|
||||
Args:
|
||||
interval (int, optional): Number of epochs between saving the checkpoint, defaults to 1.
|
||||
if save_by_iter is True, this arg refers to the number of iters between saving.
|
||||
checkpoint_dir (str, optional): File name to save the checkpoint, defaults to None.
|
||||
model (torch.nn.Module, Optional): The model to save, defaults to None. When not passing,
|
||||
'trainer.engine.model' will be used. We encourage you to pass the model in it to avoid some
|
||||
unexpected bugs, especially when using **DDP**.
|
||||
save_by_iter (bool, optional): Whether saving the checkpoint by iter, default to False.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
interval: int = 1,
|
||||
checkpoint_dir: str = None,
|
||||
model: torch.nn.Module = None,
|
||||
save_by_iter: bool = False,
|
||||
priority: int = 10):
|
||||
super().__init__(priority=priority)
|
||||
self.interval = interval
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.model = model
|
||||
self.save_by_iter = save_by_iter
|
||||
self.logger = get_dist_logger()
|
||||
|
||||
# get lr scheduler from the LRSchedulerHook before train
|
||||
self._lr_scheduler = None
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
# get lr scheduler if exists
|
||||
for hook in trainer.hooks:
|
||||
if isinstance(hook, LRSchedulerHook):
|
||||
self._lr_scheduler = hook.lr_scheduler
|
||||
break
|
||||
self.model = self.model if self.model is not None else trainer.engine.model
|
||||
|
||||
def after_train_iter(self, trainer, output, label, loss):
|
||||
"""Saves the model after a training iter.
|
||||
"""
|
||||
# save by interval
|
||||
if self.save_by_iter and trainer.cur_step % self.interval == 0:
|
||||
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
|
||||
self._lr_scheduler)
|
||||
self.logger.info(f'checkpoint for iteration {trainer.cur_step} is saved to {self.checkpoint_dir}',
|
||||
ranks=[0])
|
||||
else:
|
||||
pass
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Saves the model after a training epoch.
|
||||
"""
|
||||
# save by interval
|
||||
if trainer.cur_epoch % self.interval == 0:
|
||||
save_checkpoint(self.checkpoint_dir, trainer.cur_epoch, self.model, trainer.engine.optimizer,
|
||||
self._lr_scheduler)
|
||||
self.logger.info(f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
|
9
colossalai/legacy/trainer/hooks/_commons_.py
Normal file
9
colossalai/legacy/trainer/hooks/_commons_.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import torch
|
||||
|
||||
|
||||
def _format_number(val, prec=5):
|
||||
if isinstance(val, float):
|
||||
return f'{val:.{prec}g}'
|
||||
elif torch.is_tensor(val) and torch.is_floating_point(val):
|
||||
return f'{val.item():.{prec}g}'
|
||||
return val
|
301
colossalai/legacy/trainer/hooks/_log_hook.py
Normal file
301
colossalai/legacy/trainer/hooks/_log_hook.py
Normal file
@@ -0,0 +1,301 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ._commons_ import _format_number
|
||||
|
||||
|
||||
class LogByEpochHook(BaseHook):
|
||||
"""Hook to log by epoch.
|
||||
|
||||
Args:
|
||||
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
|
||||
interval (int, optional): Interval of printing log information, defaults to 1.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
|
||||
defaults to 1. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, interval: int = 1, priority: int = 1):
|
||||
super().__init__(priority)
|
||||
self.logger = logger
|
||||
self._interval = interval
|
||||
|
||||
def _is_epoch_to_log(self, trainer):
|
||||
return trainer.cur_epoch % self._interval == 0
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByStepHook(BaseHook):
|
||||
"""Hook to log metric by step.
|
||||
|
||||
Args:
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
|
||||
if isinstance(metric_calculator, ThroughputMetric):
|
||||
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
|
||||
else:
|
||||
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
|
||||
if isinstance(metric_calculator, ThroughputMetric):
|
||||
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
|
||||
else:
|
||||
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByEpochHook(LogByEpochHook):
|
||||
"""Specialized hook to record the metric to log.
|
||||
|
||||
Args:
|
||||
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
|
||||
interval (int, optional): Interval of printing log information, defaults to 1.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self, logger, interval: int = 1, priority: int = 10) -> None:
|
||||
super().__init__(logger, interval, priority)
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
def _get_str(self, trainer, mode):
|
||||
msg = []
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
if self._is_epoch_to_log(trainer):
|
||||
msg = self._get_str(trainer=trainer, mode='train')
|
||||
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}')
|
||||
# f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
if self._is_epoch_to_log(trainer):
|
||||
msg = self._get_str(trainer=trainer, mode='test')
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
|
||||
# f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class TensorboardHook(BaseHook):
|
||||
"""Specialized hook to record the metric to Tensorboard.
|
||||
|
||||
Args:
|
||||
log_dir (str): Directory of log.
|
||||
ranks (list): Ranks of processors.
|
||||
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`, optional): Parallel mode used in trainer,
|
||||
defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front,
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_dir: str,
|
||||
ranks: List = None,
|
||||
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
|
||||
priority: int = 10,
|
||||
) -> None:
|
||||
super().__init__(priority=priority)
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# create log dir
|
||||
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# determine the ranks to generate tensorboard logs
|
||||
self._is_valid_rank_to_log = False
|
||||
if not gpc.is_initialized(parallel_mode):
|
||||
self._is_valid_rank_to_log = True
|
||||
else:
|
||||
local_rank = gpc.get_local_rank(parallel_mode)
|
||||
|
||||
if ranks is None or local_rank in ranks:
|
||||
self._is_valid_rank_to_log = True
|
||||
|
||||
# check for
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE) and \
|
||||
not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log:
|
||||
raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group")
|
||||
|
||||
if self._is_valid_rank_to_log:
|
||||
# create workspace on only one rank
|
||||
if gpc.is_initialized(parallel_mode):
|
||||
rank = gpc.get_local_rank(parallel_mode)
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
# create workspace
|
||||
log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||
|
||||
def _log_by_iter(self, trainer, mode: str):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
if metric_calculator.epoch_only:
|
||||
continue
|
||||
val = metric_calculator.get_last_step_value()
|
||||
|
||||
if self._is_valid_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
|
||||
|
||||
def _log_by_epoch(self, trainer, mode: str):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
if metric_calculator.epoch_only:
|
||||
val = metric_calculator.get_accumulated_value()
|
||||
if self._is_valid_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
self._log_by_iter(trainer, mode='test')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
self._log_by_epoch(trainer, mode='test')
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
self._log_by_iter(trainer, mode='train')
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
self._log_by_epoch(trainer, mode='train')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogTimingByEpochHook(LogByEpochHook):
|
||||
"""Specialized hook to write timing record to log.
|
||||
|
||||
Args:
|
||||
timer (:class:`colossalai.utils.MultiTimer`): Timer for the hook.
|
||||
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
|
||||
interval (int, optional): Interval of printing log information, defaults to 1.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
log_eval (bool, optional): Whether writes in evaluation, defaults to True.
|
||||
ignore_num_train_steps (int, optional): Number of training steps to ignore, defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
timer: MultiTimer,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
ignore_num_train_steps: int = 0) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._timer = timer
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
# extra handling to avoid the unstable readings of the first
|
||||
# few training steps to affect the history mean time
|
||||
self._ignore_num_train_steps = ignore_num_train_steps
|
||||
self._is_train_step_history_trimmed = False
|
||||
|
||||
def _get_message(self, mode):
|
||||
msg = []
|
||||
for timer_name, timer in self._timer:
|
||||
if timer_name.startswith(mode):
|
||||
last_elapsed_time = timer.get_elapsed_time()
|
||||
if timer.has_history:
|
||||
if timer_name == 'Train-step' and not self._is_train_step_history_trimmed:
|
||||
timer._history = timer._history[self._ignore_num_train_steps:]
|
||||
self._is_train_step_history_trimmed = True
|
||||
history_mean = timer.get_history_mean()
|
||||
history_sum = timer.get_history_sum()
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s'
|
||||
)
|
||||
else:
|
||||
msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
|
||||
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
msg = self._get_message('Train')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg} | #steps/epoch = {trainer.steps_per_epoch}')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
"""Writes log after finishing a testing epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
|
||||
msg = self._get_message('Test')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMemoryByEpochHook(LogByEpochHook):
|
||||
"""Specialized Hook to write memory usage record to log.
|
||||
|
||||
Args:
|
||||
logger (:class:`colossalai.logging.DistributedLogger`): Logger for recording the log information.
|
||||
interval (int, optional): Interval of printing log information, defaults to 1.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 1. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
log_eval (bool, optional): Whether writes in evaluation, defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
report_cpu: bool = False, # no reference
|
||||
) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
|
||||
|
||||
def before_train(self, trainer):
|
||||
"""Resets before training.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
report_memory_usage('Before-train', self.logger)
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger)
|
||||
|
||||
def after_test(self, trainer):
|
||||
"""Reports after testing.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
|
||||
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger)
|
48
colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
Normal file
48
colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.registry import HOOKS
|
||||
|
||||
from ._metric_hook import LearningRateMetric, MetricHook
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LRSchedulerHook(MetricHook):
|
||||
r"""Build LR scheduler for trainer.
|
||||
|
||||
Args:
|
||||
lr_scheduler (:class:`colossalai.nn.lr_scheduler`): The specific LR scheduler
|
||||
in range of ``colossalai.nn.lr_scheduler``, more details about ``lr_scheduler`` could be found in
|
||||
`lr_scheduler <https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/nn/lr_scheduler>`_.
|
||||
by_epoch (bool): If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch.
|
||||
store_lr_in_state (bool, optional): If `True`, store the learning rate in each state, defaults to `True`.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 1. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lr_scheduler,
|
||||
by_epoch: bool,
|
||||
store_lr_in_state: bool = True,
|
||||
priority: int = 1,
|
||||
):
|
||||
super().__init__(priority=priority)
|
||||
self.by_epoch = by_epoch
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.store_lr_in_state = store_lr_in_state
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
|
||||
initial_lr=self.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
if self.by_epoch:
|
||||
self.lr_scheduler.step()
|
||||
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
||||
if not self.by_epoch:
|
||||
self.lr_scheduler.step()
|
||||
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
|
438
colossalai/legacy/trainer/hooks/_metric_hook.py
Normal file
438
colossalai/legacy/trainer/hooks/_metric_hook.py
Normal file
@@ -0,0 +1,438 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.communication import all_reduce
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ._commons_ import _format_number
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
"""A basic class of metric collectors. It collects a specific
|
||||
metric during training or evaluation and would always be used with
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
collector works.
|
||||
|
||||
Args:
|
||||
epoch_only (bool): Whether the metric only read for the full epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool):
|
||||
# is the metric only read for the full epoch
|
||||
self._epoch_only = epoch_only
|
||||
|
||||
@property
|
||||
def epoch_only(self):
|
||||
"""Returns :attr:`epoch_only`.
|
||||
"""
|
||||
return self._epoch_only
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Resets the metric to it's initial state.
|
||||
By default, this is called at the start of each epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""Updates the metric's state using the passed batch output.
|
||||
By default, this is called once for each batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_last_step_value(self) -> float:
|
||||
"""Returns the metric value in the last iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_accumulated_value(self):
|
||||
"""Computes the metric based on it's accumulated state.
|
||||
By default, this is called at the end of each epoch.
|
||||
|
||||
:return: the actual quantity of interest
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def is_better(a, b) -> bool:
|
||||
"""Compares a and b, and returns whether a is better than b
|
||||
|
||||
:return: The result of comparison
|
||||
:rtype: bool
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LossMetric(Metric):
|
||||
"""A metric collector for loss.
|
||||
|
||||
Args:
|
||||
epoch_only (bool): Whether the metric only read for the full epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
|
||||
"""
|
||||
self.last_step_loss.zero_()
|
||||
self.accum_loss.zero_()
|
||||
self.count = 0
|
||||
|
||||
def update(self, loss) -> None:
|
||||
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
|
||||
It expects the output has loss.
|
||||
|
||||
Args:
|
||||
loss (:class:`torch.tensor`): Current loss of the output.
|
||||
"""
|
||||
# expect output to be logits, label and loss
|
||||
loss_ = loss.detach()
|
||||
self.last_step_loss.copy_(loss_)
|
||||
self.accum_loss.add_(loss_)
|
||||
self.count += 1
|
||||
|
||||
def get_accumulated_value(self):
|
||||
"""Returns accumulated loss.
|
||||
"""
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
|
||||
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
|
||||
|
||||
self.accum_loss.div_(self.count)
|
||||
return self.accum_loss.item()
|
||||
|
||||
def get_last_step_value(self) -> float:
|
||||
"""Returns :attr:`last_step_loss`.
|
||||
"""
|
||||
return self.last_step_loss.cpu().item()
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b):
|
||||
return a < b
|
||||
|
||||
|
||||
class LearningRateMetric(Metric):
|
||||
"""A metric collector for learning rate.
|
||||
|
||||
Args:
|
||||
epoch_only (bool): Whether the metric only read for the full epoch.
|
||||
initial_lr (float, optional): Initial learning rate, defaults to 0.0.
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.lr = initial_lr
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def update(self, lr) -> None:
|
||||
self.lr = lr
|
||||
|
||||
def get_last_step_value(self) -> float:
|
||||
return self.lr
|
||||
|
||||
def get_accumulated_value(self):
|
||||
return self.lr
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class AccuracyMetric(Metric):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks.
|
||||
|
||||
Args:
|
||||
epoch_only (bool): Whether the metric only read for the full epoch.
|
||||
accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task.
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.acc = accuracy_func
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
self.last_step_correct.zero_()
|
||||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, targets, batch_size) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
Args:
|
||||
logits (:class:`torch.tensor`): The logits output of the model.
|
||||
targets (:class:`torch.tensor`): Real labels of the dataset.
|
||||
batch_size (int): Batch size of the task.
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(targets, (list, tuple)):
|
||||
targets = targets[0]
|
||||
# update
|
||||
correct = self.acc(logits, targets)
|
||||
|
||||
self.last_step_sum.fill_(batch_size)
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def get_last_step_value(self) -> float:
|
||||
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
|
||||
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
|
||||
return _format_number((self.last_step_correct / self.last_step_sum).cpu().item())
|
||||
|
||||
def get_accumulated_value(self):
|
||||
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
|
||||
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
|
||||
return (self.accumulated_correct / self.accumulated_sum).item()
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
|
||||
class MetricHook(BaseHook):
|
||||
"""Specialized hook classes for :class:`Metric`.
|
||||
Some help metric collectors initialize, reset and
|
||||
update their states. Others are used to display and
|
||||
record the metric.
|
||||
|
||||
Args:
|
||||
priority (int): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 1. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
priority: int,
|
||||
):
|
||||
super().__init__(priority)
|
||||
self._is_stage_to_compute = is_no_pp_or_last_stage()
|
||||
|
||||
def _check_metric_states_initialization(self, trainer):
|
||||
if 'metrics' not in trainer.states:
|
||||
self.init_runner_states(trainer, 'metrics', dict(train={}, test={}))
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LossHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Loss`.
|
||||
|
||||
Args:
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 0. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
|
||||
if self._is_stage_to_compute:
|
||||
self.train_loss = LossMetric(epoch_only=False)
|
||||
self.test_loss = LossMetric(epoch_only=True)
|
||||
|
||||
# register the metric calculator
|
||||
trainer.states['metrics']['train']['Loss'] = self.train_loss
|
||||
trainer.states['metrics']['test']['Loss'] = self.test_loss
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.train_loss.reset()
|
||||
|
||||
def after_train_iter(self, trainer, logits, label, loss):
|
||||
if self._is_stage_to_compute:
|
||||
self.train_loss.update(loss)
|
||||
|
||||
def before_test_epoch(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.test_loss.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, loss):
|
||||
if self._is_stage_to_compute:
|
||||
self.test_loss.update(loss)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class AccuracyHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy`.
|
||||
|
||||
Args:
|
||||
accuracy_func (:class:`typing.Callable`): Accuracy function for the classification task.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 0. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self, accuracy_func: Callable, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
self.accuracy_func = accuracy_func
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test']['Accuracy'] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
if self._is_stage_to_compute:
|
||||
batch_size = trainer.engine.schedule.batch_size
|
||||
self.metric.update(logits, targets, batch_size)
|
||||
|
||||
|
||||
class ThroughputMetric(Metric):
|
||||
"""Metric for :class:`Throughput`.
|
||||
|
||||
Args:
|
||||
epoch_only (bool): Whether the metric only read for the full epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0, use_local: bool = False):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.ignored_steps = ignored_steps
|
||||
self.cur_steps = 0
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_current_device())
|
||||
self._tflop_per_step = tflop_per_step
|
||||
self._use_local = use_local
|
||||
|
||||
def reset(self) -> None:
|
||||
# self.cur_steps = 0
|
||||
self.accumulated_num_samples.zero_()
|
||||
self.accumulated_used_time.zero_()
|
||||
self.last_step_num_samples.zero_()
|
||||
self.last_step_used_time.zero_()
|
||||
|
||||
def update(self, num_samples, time) -> None:
|
||||
self.cur_steps += 1
|
||||
self.last_step_num_samples.fill_(num_samples)
|
||||
self.last_step_used_time.fill_(time)
|
||||
if self.cur_steps >= self.ignored_steps:
|
||||
self.accumulated_num_samples += self.last_step_num_samples
|
||||
self.accumulated_used_time += self.last_step_used_time
|
||||
|
||||
def get_last_step_value(self) -> float:
|
||||
if self._use_local:
|
||||
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
|
||||
else:
|
||||
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
|
||||
|
||||
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
|
||||
return sample_per_sec
|
||||
|
||||
def get_last_step_info(self) -> str:
|
||||
if self._use_local:
|
||||
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
|
||||
else:
|
||||
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
|
||||
|
||||
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
|
||||
if self._tflop_per_step > 0:
|
||||
tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
|
||||
return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
|
||||
else:
|
||||
return f"{sample_per_sec} sample_per_sec"
|
||||
|
||||
def get_accumulated_value(self) -> float:
|
||||
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
|
||||
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class ThroughputHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Throughput`. Hook to measure execution throughput (samples/sec).
|
||||
|
||||
Args:
|
||||
ignored_steps (int, optional): the number of initial training steps to ignore.
|
||||
priority (int, optional): Priority in the printing, hooks with small priority will be printed in front
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
tflop_per_step(int, optional): tera floating point operations per step.
|
||||
use_local (bool, optional): Whether to use local time for throughput calculation.
|
||||
"""
|
||||
|
||||
def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0, use_local=False):
|
||||
super().__init__(priority)
|
||||
self.ignored_steps = ignored_steps
|
||||
self._tflop_per_step = tflop_per_step
|
||||
self._use_local = use_local
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = ThroughputMetric(epoch_only=True,
|
||||
ignored_steps=self.ignored_steps,
|
||||
tflop_per_step=self._tflop_per_step,
|
||||
use_local=self._use_local)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['train']['Throughput'] = self.metric
|
||||
trainer.states['metrics']['test']['Throughput'] = self.metric
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.engine.schedule.batch_size,
|
||||
trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.engine.schedule.batch_size,
|
||||
trainer._timer.get_timer('Test-step').get_elapsed_time())
|
Reference in New Issue
Block a user