mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 01:24:04 +00:00
Migrated project
This commit is contained in:
5
colossalai/trainer/__init__.py
Normal file
5
colossalai/trainer/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from ._trainer import Trainer
|
||||
from .hooks import *
|
||||
from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D
|
||||
|
||||
__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D']
|
333
colossalai/trainer/_trainer.py
Normal file
333
colossalai/trainer/_trainer.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.builder import build_hooks
|
||||
from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
|
||||
from colossalai.context import Config
|
||||
from colossalai.engine import Engine
|
||||
from colossalai.logging import get_global_dist_logger
|
||||
from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
|
||||
from colossalai.nn.data import DataParallelSampler
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""This a class tending for easy deployments of users' training and evaluation instead of
|
||||
writing their own scripts. It is similar with ``ignite.engine`` and ``keras.engine``, but is
|
||||
called `Trainer`.
|
||||
|
||||
:param engine: Engine responsible for the process function
|
||||
:param hooks_cfg: The configuration of hooks
|
||||
:param verbose: If True, additional information will be printed
|
||||
:type engine: Engine
|
||||
:type hoooks_cfg: Config, optional
|
||||
:type verbose: bool, optional
|
||||
"""
|
||||
def __init__(self,
|
||||
engine: Engine,
|
||||
hooks_cfg: Optional[Config] = None,
|
||||
verbose: bool = False):
|
||||
# training-ralated params
|
||||
self._engine = engine
|
||||
self._max_epochs = float('inf')
|
||||
self._max_steps = float('inf')
|
||||
self._cur_epoch = 0
|
||||
self._cur_step = 0
|
||||
|
||||
# data-related params
|
||||
self._train_dataloader = None
|
||||
self._test_dataloader = None
|
||||
|
||||
# misc params
|
||||
self._display_progress = False
|
||||
self._logger = get_global_dist_logger()
|
||||
self._verbose = verbose
|
||||
|
||||
# hooks can store states in this dict, and could be consumed by other hooks
|
||||
self.states = {}
|
||||
|
||||
# build hooks
|
||||
self.hooks = list()
|
||||
if hooks_cfg is not None:
|
||||
for cfg in hooks_cfg:
|
||||
hook = build_hooks(cfg, self)
|
||||
self.hooks.append(hook)
|
||||
self.hooks.sort(key=lambda hook: hook.priority)
|
||||
if self._verbose:
|
||||
for hook in self.hooks:
|
||||
self._logger.info(
|
||||
f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
|
||||
|
||||
# timer
|
||||
self._timer = get_global_multitimer()
|
||||
|
||||
@property
|
||||
def cur_epoch(self):
|
||||
"""Returns the index of the current epoch.
|
||||
"""
|
||||
return self._cur_epoch
|
||||
|
||||
@property
|
||||
def cur_step(self):
|
||||
"""Returns how many iteration steps have been processed.
|
||||
"""
|
||||
return self._cur_step
|
||||
|
||||
def call_hooks(self, func, output=None):
|
||||
"""Calls specific hooks in the current time point.
|
||||
|
||||
:param func: A string represents the time point
|
||||
:param output: Output of the model after running a iteration or None in any other time points
|
||||
:type func: str
|
||||
:type output: optional
|
||||
"""
|
||||
# Only after iter hook will receive output
|
||||
for hook in self.hooks:
|
||||
if output is None:
|
||||
getattr(hook, func)()
|
||||
else:
|
||||
getattr(hook, func)(*output)
|
||||
|
||||
def exceed_max_step(self):
|
||||
"""Checks whether the trainer exceeds the maximum number of runnning iterations.
|
||||
"""
|
||||
return self._cur_step >= self._max_steps
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
"""Sets current epoch number.
|
||||
|
||||
:param epoch: Epoch number to be set
|
||||
:type epoch: int
|
||||
"""
|
||||
self._cur_epoch = epoch
|
||||
|
||||
def _recover_steps(self):
|
||||
step = self.cur_step * self._engine.schedule.num_steps
|
||||
self._cur_step = step
|
||||
|
||||
def _set_display_progress(self, display_progress: bool):
|
||||
self._display_progress = display_progress and is_dp_rank_0(
|
||||
) and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
def _train_epoch(self, epoch: int = None):
|
||||
# set sampler epoch
|
||||
if epoch is not None and \
|
||||
hasattr(self._engine.train_dataloader, 'sampler') and \
|
||||
isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
|
||||
self._engine.train_dataloader.sampler.set_epoch(epoch)
|
||||
|
||||
self._engine.train()
|
||||
|
||||
progress = range(self._engine.schedule.num_steps)
|
||||
if self._display_progress:
|
||||
if epoch is None:
|
||||
progress = tqdm(progress, desc='[Train]')
|
||||
else:
|
||||
progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
|
||||
|
||||
# train 1 epoch
|
||||
self.call_hooks('before_train_epoch')
|
||||
self._timer.start('train-epoch')
|
||||
for _ in progress:
|
||||
self._cur_step += 1
|
||||
|
||||
self.call_hooks('before_train_iter')
|
||||
self._timer.start('train-step')
|
||||
logits, label, loss = self._engine.step()
|
||||
self._timer.stop('train-step', keep_in_history=True)
|
||||
self.call_hooks('after_train_iter', output=(logits, label, loss))
|
||||
|
||||
if self.exceed_max_step():
|
||||
# stop when max iter is reached
|
||||
break
|
||||
self._timer.stop('train-epoch', keep_in_history=True)
|
||||
self.call_hooks('after_train_epoch')
|
||||
self._timer.reset('train-step')
|
||||
|
||||
def _eval(self,
|
||||
epoch: int = None,
|
||||
return_loss: bool = True):
|
||||
# switch engine status
|
||||
self._engine.eval()
|
||||
|
||||
self.call_hooks('before_test')
|
||||
with torch.no_grad():
|
||||
# prepare progress bar
|
||||
progress = range(self._engine.schedule.num_steps)
|
||||
if self._display_progress:
|
||||
desc = 'Evaluation'
|
||||
if epoch is not None:
|
||||
desc = '[Epoch %d val]' % epoch
|
||||
progress = tqdm(progress, desc=desc)
|
||||
|
||||
self.call_hooks('before_test_epoch')
|
||||
self._timer.start('test-epoch')
|
||||
for _ in progress:
|
||||
self.call_hooks('before_test_iter')
|
||||
self._timer.start('test-step')
|
||||
logits, label, loss = self._engine.step(
|
||||
return_loss=return_loss)
|
||||
self._timer.stop('test-step', keep_in_history=True)
|
||||
self.call_hooks('after_test_iter',
|
||||
output=(logits, label, loss))
|
||||
self._timer.stop('test-epoch', keep_in_history=True)
|
||||
self.call_hooks('after_test_epoch')
|
||||
self.call_hooks('after_test')
|
||||
self._timer.reset('test-step')
|
||||
self._timer.reset('test-epoch')
|
||||
|
||||
def fit(self,
|
||||
train_dataloader: DataLoader,
|
||||
test_dataloader: DataLoader = None,
|
||||
max_epochs: int = None,
|
||||
max_steps: int = None,
|
||||
test_interval: int = 1,
|
||||
display_progress: bool = False):
|
||||
"""Trains the model to fit training data.
|
||||
|
||||
:param train_dataloader: DataLoader in training
|
||||
:param test_dataloader: DataLoader in testing
|
||||
:param max_epochs: Maximum number of epoches
|
||||
:param max_steps: Maximum number of running iterations
|
||||
:param test_interval: Interval of testing
|
||||
:param display_progress: If True, the training progress will be printed
|
||||
:type train_dataloader: DataLoader
|
||||
:type test_dataloader: DataLoader
|
||||
:type max_epochs: int
|
||||
:type max_steps: int
|
||||
:type test_interval: int
|
||||
:type display_progress: bool
|
||||
"""
|
||||
|
||||
# prepare dataloaders
|
||||
self._train_dataloader = train_dataloader
|
||||
self._engine.set_dataloader(self._train_dataloader, train=True)
|
||||
self._engine.train()
|
||||
|
||||
should_test = False
|
||||
if test_dataloader is not None:
|
||||
self._test_dataloader = test_dataloader
|
||||
self._engine.set_dataloader(self._test_dataloader, train=False)
|
||||
should_test = True
|
||||
|
||||
# decide the
|
||||
if max_epochs is not None:
|
||||
self._max_epochs = max_epochs
|
||||
if max_steps is not None:
|
||||
self._max_steps = max_steps
|
||||
self._set_display_progress(display_progress)
|
||||
|
||||
# start train
|
||||
self.call_hooks('before_train')
|
||||
|
||||
# recover step value if resuming training
|
||||
if self.cur_epoch != 0:
|
||||
self._recover_steps()
|
||||
|
||||
last_epoch = self._cur_epoch
|
||||
|
||||
for epoch in range(last_epoch, self._max_epochs):
|
||||
self._cur_epoch += 1
|
||||
|
||||
# train for one epoch
|
||||
self._train_epoch(epoch)
|
||||
|
||||
# start eval
|
||||
if should_test and epoch % test_interval == 0:
|
||||
self._eval(epoch, return_loss=True)
|
||||
|
||||
# check for termination
|
||||
if self.exceed_max_step():
|
||||
self._logger.info(
|
||||
f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
|
||||
break
|
||||
self.call_hooks('after_train')
|
||||
self._timer.reset('train-epoch')
|
||||
|
||||
def evaluate(self,
|
||||
test_dataloader: DataLoader,
|
||||
display_progress: bool = False):
|
||||
"""Evaluates the model with testing data.
|
||||
|
||||
:param test_dataloader: DataLoader in testing
|
||||
:param display_progress: If True, the evaluation progress will be printed
|
||||
:type test_dataloader: DataLoader
|
||||
:type display_progress: bool, optional
|
||||
"""
|
||||
# set dataloader
|
||||
self._test_dataloader = test_dataloader
|
||||
self._engine.set_dataloader(self._test_dataloader, train=True)
|
||||
|
||||
# set
|
||||
self._set_display_progress(display_progress)
|
||||
|
||||
# eval
|
||||
self._eval(return_loss=True)
|
||||
|
||||
def predict(self, data: Union[Tensor, List[Tensor]]):
|
||||
"""Uses trained model to make a prediction for a tensor or a tensor list.
|
||||
|
||||
:param data: Data as the input
|
||||
:type data: Union[Tensor, List[Tensor]
|
||||
:return: The output of model as the prediction
|
||||
:rtype: Tensor
|
||||
"""
|
||||
# predict without labels
|
||||
if isinstance(data, (list, tuple)):
|
||||
assert isinstance(data[0], Tensor)
|
||||
else:
|
||||
assert isinstance(data, Tensor)
|
||||
self._engine.eval()
|
||||
|
||||
# prepare a list of (data, label) to make it iterable
|
||||
# for compatibility with schedule
|
||||
simple_dataloader = [(data, None)]
|
||||
self._engine.set_dataloader(simple_dataloader)
|
||||
output, _, _ = self._engine.step(return_loss=False)
|
||||
return output
|
||||
|
||||
def save(self, path: str, suffix: str = ''):
|
||||
"""Saves the model to a file.
|
||||
|
||||
:param path: Relative path of the file
|
||||
:param suffix: Suffix of the file
|
||||
:type path: str
|
||||
:type suffix: str, optional
|
||||
"""
|
||||
save_path = get_checkpoint_path(path,
|
||||
self._cur_epoch,
|
||||
suffix=suffix)
|
||||
save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
|
||||
self._engine.get_optimizer(),
|
||||
self._engine.get_lr_scheduler())
|
||||
|
||||
def load(self,
|
||||
path: str,
|
||||
finetune: bool = False,
|
||||
strict: bool = False):
|
||||
"""Loads parameters to the model from a file.
|
||||
|
||||
:param path: Relative path of the file
|
||||
:param finetune: Whether allows to load a part of the model
|
||||
:param strict: Whether loads a model that has the same shape of parameters
|
||||
:type path: str
|
||||
:type finetune: bool, optional
|
||||
:type strict: bool, optional
|
||||
"""
|
||||
last_epoch, _ = load_checkpoint(path,
|
||||
self._engine.get_model(),
|
||||
self._engine.get_optimizer(),
|
||||
self._engine.get_lr_scheduler(),
|
||||
finetune=finetune,
|
||||
strict=strict)
|
||||
if finetune:
|
||||
self.set_epoch(0)
|
||||
else:
|
||||
self.set_epoch(last_epoch)
|
11
colossalai/trainer/hooks/__init__.py
Normal file
11
colossalai/trainer/hooks/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from ._base_hook import BaseHook
|
||||
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
|
||||
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
|
||||
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
|
||||
|
||||
__all__ = [
|
||||
'BaseHook', 'MetricHook',
|
||||
'LoadCheckpointHook', 'SaveCheckpointHook',
|
||||
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
|
||||
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
|
||||
]
|
107
colossalai/trainer/hooks/_base_hook.py
Normal file
107
colossalai/trainer/hooks/_base_hook.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.logging import get_global_dist_logger
|
||||
from .._trainer import Trainer
|
||||
|
||||
|
||||
class BaseHook(ABC):
|
||||
"""This class allows users to add desired actions in specific time points
|
||||
during training or evaluation.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
def __init__(self, trainer: Trainer, priority: int) -> None:
|
||||
self.trainer = trainer
|
||||
self.priority = priority
|
||||
self.logger = get_global_dist_logger()
|
||||
|
||||
def before_train(self):
|
||||
"""Actions before training.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train(self):
|
||||
"""Actions after training.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train_iter(self):
|
||||
"""Actions before running a training iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
|
||||
"""Actions after running a training iteration.
|
||||
|
||||
:param output: Output of the model
|
||||
:param label: Labels of the input data
|
||||
:param loss: Loss between the output and input data
|
||||
:type output: Tensor
|
||||
:type label: Tensor
|
||||
:type loss: Tensor
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train_epoch(self):
|
||||
"""Actions before starting a training epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train_epoch(self):
|
||||
"""Actions after finishing a training epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test(self):
|
||||
"""Actions before evaluation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test(self):
|
||||
"""Actions after evaluation.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test_epoch(self):
|
||||
"""Actions before starting a testing epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test_epoch(self):
|
||||
"""Actions after finishing a testing epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test_iter(self):
|
||||
"""Actions before running a testing iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test_iter(self, output: Tensor, label: Tensor, loss: Tensor):
|
||||
"""Actions after running a testing iteration.
|
||||
|
||||
:param output: Output of the model
|
||||
:param label: Labels of the input data
|
||||
:param loss: Loss between the output and input data
|
||||
:type output: Tensor
|
||||
:type label: Tensor
|
||||
:type loss: Tensor
|
||||
"""
|
||||
pass
|
||||
|
||||
def init_runner_states(self, key, val):
|
||||
"""Initializes trainer's state.
|
||||
|
||||
:param key: Key of reseting state
|
||||
:param val: Value of reseting state
|
||||
"""
|
||||
if key not in self.trainer.states:
|
||||
self.trainer.states[key] = val
|
110
colossalai/trainer/hooks/_checkpoint_hook.py
Normal file
110
colossalai/trainer/hooks/_checkpoint_hook.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os.path as osp
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.trainer import Trainer
|
||||
from colossalai.utils import is_dp_rank_0
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class SaveCheckpointHook(BaseHook):
|
||||
"""Saves the model by interval in training process.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param interval: Saving interval
|
||||
:param checkpoint_dir: Directory of saving checkpoint
|
||||
:param suffix: Saving suffix of the file
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type interval: int, optional
|
||||
:type checkpoint_dir: int, optional
|
||||
:type suffix: str, optional
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trainer: Trainer,
|
||||
interval: int = 1,
|
||||
checkpoint_dir: str = None,
|
||||
suffix: str = '',
|
||||
priority: int = 0):
|
||||
super().__init__(trainer=trainer, priority=priority)
|
||||
assert isinstance(trainer, Trainer), \
|
||||
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
|
||||
self.interval = interval
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.suffix = suffix
|
||||
|
||||
def after_train_epoch(self):
|
||||
"""Saves the model after a training epoch.
|
||||
"""
|
||||
# save by interval
|
||||
if self.trainer.cur_epoch % self.interval == 0:
|
||||
# only gpus with data parallel rank equals to 0 write to the disk
|
||||
if is_dp_rank_0():
|
||||
self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
|
||||
self.logger.info(
|
||||
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
|
||||
|
||||
# wait until everyone is done
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LoadCheckpointHook(BaseHook):
|
||||
"""Loads the model before training process.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param checkpoint_dir: Directory of saving checkpoint
|
||||
:param epoch: Epoch number to be set
|
||||
:param finetune: Whether allows to load a part of the model
|
||||
:param strict: Whether loads a model that has the same shape of parameters
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type checkpoint_dir: str, optional
|
||||
:type epoch: str, optional
|
||||
:type finetune: bool, optional
|
||||
:type strict: bool, optional
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trainer: Trainer = None,
|
||||
checkpoint_dir: str = None,
|
||||
epoch: int = -1,
|
||||
finetune: bool = False,
|
||||
strict: bool = False,
|
||||
priority: int = 10) -> None:
|
||||
assert isinstance(trainer, Trainer), \
|
||||
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
|
||||
self.epoch = epoch
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.finetune = finetune
|
||||
self.strict = strict
|
||||
super().__init__(trainer=trainer, priority=priority)
|
||||
|
||||
def before_train(self):
|
||||
"""Loads parameters to the model before training.
|
||||
"""
|
||||
if self.epoch == -1:
|
||||
path = get_latest_checkpoint_path(self.checkpoint_dir)
|
||||
else:
|
||||
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
|
||||
if osp.exists(path):
|
||||
self.trainer.load(
|
||||
path, finetune=self.finetune, strict=self.strict)
|
||||
self.logger.info(
|
||||
f'loaded checkpoint from {path}')
|
||||
else:
|
||||
raise FileNotFoundError(f'checkpoint is not found at {path}')
|
||||
|
||||
# Some utilities want to load a checkpoint without distributed being initialized
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
247
colossalai/trainer/hooks/_log_hook.py
Normal file
247
colossalai/trainer/hooks/_log_hook.py
Normal file
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.trainer._trainer import Trainer
|
||||
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
|
||||
is_tp_rank_0, is_no_pp_or_last_stage
|
||||
from ._metric_hook import MetricHook
|
||||
|
||||
|
||||
def _format_number(val):
|
||||
if isinstance(val, float):
|
||||
return f'{val:.5f}'
|
||||
elif torch.is_floating_point(val):
|
||||
return f'{val.item():.5f}'
|
||||
return val
|
||||
|
||||
|
||||
class EpochIntervalHook(MetricHook):
|
||||
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
|
||||
super().__init__(trainer, priority)
|
||||
self._interval = interval
|
||||
|
||||
def _is_epoch_to_log(self):
|
||||
return self.trainer.cur_epoch % self._interval == 0
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByEpochHook(EpochIntervalHook):
|
||||
"""Specialized Hook to record the metric to log.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
:param interval: Recording interval
|
||||
:type interval: int, optional
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
|
||||
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
def _get_str(self, mode):
|
||||
msg = []
|
||||
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
|
||||
msg.append(
|
||||
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
|
||||
msg = ', '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self):
|
||||
if self._is_epoch_to_log():
|
||||
msg = self._get_str(mode='train')
|
||||
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(
|
||||
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
def after_test_epoch(self):
|
||||
if self._is_epoch_to_log():
|
||||
msg = self._get_str(mode='test')
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(
|
||||
f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class TensorboardHook(MetricHook):
|
||||
"""Specialized Hook to record the metric to Tensorboard.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
:param log_dir: Directory of log
|
||||
:type log_dir: str, optional
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
|
||||
super().__init__(trainer=trainer, priority=priority)
|
||||
self._is_rank_to_log = is_no_pp_or_last_stage()
|
||||
|
||||
if self._is_rank_to_log:
|
||||
# create workspace on only one rank
|
||||
if gpc.is_initialized(ParallelMode.GLOBAL):
|
||||
rank = gpc.get_global_rank()
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
log_dir = osp.join(log_dir, f'rank_{rank}')
|
||||
|
||||
# create workspace
|
||||
if not osp.exists(log_dir):
|
||||
os.makedirs(log_dir)
|
||||
|
||||
self.writer = SummaryWriter(
|
||||
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||
|
||||
def after_train_iter(self, *args):
|
||||
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
|
||||
if metric_calculator.epoch_only:
|
||||
continue
|
||||
val = metric_calculator.get_last_step_value()
|
||||
if self._is_rank_to_log:
|
||||
self.writer.add_scalar(
|
||||
f'{metric_name}/train', val, self.trainer.cur_step)
|
||||
|
||||
def after_test_iter(self, *args):
|
||||
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
|
||||
if metric_calculator.epoch_only:
|
||||
continue
|
||||
val = metric_calculator.get_last_step_value()
|
||||
if self._is_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/test', val,
|
||||
self.trainer.cur_step)
|
||||
|
||||
def after_test_epoch(self):
|
||||
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
|
||||
if metric_calculator.epoch_only:
|
||||
val = metric_calculator.get_accumulated_value()
|
||||
if self._is_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/test', val,
|
||||
self.trainer.cur_step)
|
||||
|
||||
def after_train_epoch(self):
|
||||
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
|
||||
if metric_calculator.epoch_only:
|
||||
val = metric_calculator.get_accumulated_value()
|
||||
if self._is_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/train', val,
|
||||
self.trainer.cur_step)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogTimingByEpochHook(EpochIntervalHook):
|
||||
"""Specialized Hook to write timing record to log.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
:param interval: Recording interval
|
||||
:type interval: int, optional
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int, optional
|
||||
:param log_eval: Whether writes in evaluation
|
||||
:type log_eval: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trainer: Trainer,
|
||||
interval: int = 1,
|
||||
priority: int = 1,
|
||||
log_eval: bool = True
|
||||
) -> None:
|
||||
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
||||
set_global_multitimer_status(True)
|
||||
self._global_timer = get_global_multitimer()
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
|
||||
|
||||
def _get_message(self):
|
||||
msg = []
|
||||
for timer_name, timer in self._global_timer:
|
||||
last_elapsed_time = timer.get_elapsed_time()
|
||||
if timer.has_history:
|
||||
history_mean = timer.get_history_mean()
|
||||
history_sum = timer.get_history_sum()
|
||||
msg.append(
|
||||
f'{timer_name}: last elapsed time = {last_elapsed_time}, '
|
||||
f'history sum = {history_sum}, history mean = {history_mean}')
|
||||
else:
|
||||
msg.append(
|
||||
f'{timer_name}: last elapsed time = {last_elapsed_time}')
|
||||
|
||||
msg = ', '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log() and self._is_rank_to_log:
|
||||
msg = self._get_message()
|
||||
self.logger.info(
|
||||
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
def after_test_epoch(self):
|
||||
"""Writes log after finishing a testing epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
|
||||
msg = self._get_message()
|
||||
self.logger.info(
|
||||
f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMemoryByEpochHook(EpochIntervalHook):
|
||||
"""Specialized Hook to write memory usage record to log.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:type trainer: Trainer
|
||||
:param interval: Recording interval
|
||||
:type interval: int, optional
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int, optional
|
||||
:param log_eval: Whether writes in evaluation
|
||||
:type log_eval: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trainer: Trainer,
|
||||
interval: int = 1,
|
||||
priority: int = 1,
|
||||
log_eval: bool = True
|
||||
) -> None:
|
||||
super().__init__(trainer=trainer, interval=interval, priority=priority)
|
||||
set_global_multitimer_status(True)
|
||||
self._global_timer = get_global_multitimer()
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
|
||||
|
||||
def before_train(self):
|
||||
"""Resets before training.
|
||||
"""
|
||||
if self._is_epoch_to_log() and self._is_rank_to_log:
|
||||
report_memory_usage('before-train')
|
||||
|
||||
def after_train_epoch(self):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log() and self._is_rank_to_log:
|
||||
report_memory_usage(
|
||||
f'After Train - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
|
||||
|
||||
def after_test(self):
|
||||
"""Reports after testing.
|
||||
"""
|
||||
if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
|
||||
report_memory_usage(
|
||||
f'After Test - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
|
185
colossalai/trainer/hooks/_metric_hook.py
Normal file
185
colossalai/trainer/hooks/_metric_hook.py
Normal file
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from ._base_hook import BaseHook
|
||||
from .._trainer import Trainer
|
||||
from ..metric import Loss, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
|
||||
|
||||
|
||||
class MetricHook(BaseHook):
|
||||
"""Specialized hook classes for :class:`Metric`.
|
||||
Some help metric collectors initialize, reset and
|
||||
update their states. Others are used to display and
|
||||
record the metric.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: Trainer, priority: int):
|
||||
super().__init__(trainer, priority)
|
||||
self._is_stage_to_log = is_no_pp_or_last_stage()
|
||||
self._check_metric_states_initialization()
|
||||
|
||||
def _check_metric_states_initialization(self):
|
||||
if 'metrics' not in self.trainer.states:
|
||||
self.init_runner_states('metrics', dict(train={}, test={}))
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LossHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Loss`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
||||
super().__init__(trainer, priority)
|
||||
|
||||
if self._is_stage_to_log:
|
||||
self.metric = Loss(epoch_only=False)
|
||||
|
||||
# register the metric calculator
|
||||
self.trainer.states['metrics']['train'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
self.trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_train_epoch(self):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, logits, label, loss):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.update(loss)
|
||||
|
||||
def before_test_epoch(self):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, logits, label, loss):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.update(loss)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy2DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy2D`.
|
||||
It acts the same as :class:`AccuracyHook`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
||||
super().__init__(trainer, priority)
|
||||
|
||||
if self._is_stage_to_log:
|
||||
self.metric = Accuracy2D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
self.trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, logits, label, *args):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy2p5DHook(MetricHook):
|
||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
||||
super().__init__(trainer, priority)
|
||||
|
||||
if self._is_stage_to_log:
|
||||
self.metric = Accuracy2p5D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
self.trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, logits, label, *args):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy3DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy3D`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
trainer: Trainer,
|
||||
input_parallel_mode: ParallelMode,
|
||||
weight_parallel_mode: ParallelMode,
|
||||
priority: int = 10):
|
||||
super().__init__(trainer, priority)
|
||||
|
||||
if self._is_stage_to_log:
|
||||
self.metric = Accuracy3D(epoch_only=True,
|
||||
input_parallel_mode=input_parallel_mode,
|
||||
weight_parallel_mode=weight_parallel_mode)
|
||||
|
||||
# register the metric
|
||||
self.trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, logits, label, *args):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class AccuracyHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, trainer: Trainer, priority: int = 10):
|
||||
super().__init__(trainer, priority)
|
||||
|
||||
if self._is_stage_to_log:
|
||||
self.metric = Accuracy(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
self.trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, logits, label, *args):
|
||||
if self._is_stage_to_log:
|
||||
self.metric.update(logits, label)
|
307
colossalai/trainer/metric.py
Normal file
307
colossalai/trainer/metric.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.communication import all_gather
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer._parallel_utilities import _gather
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_last_group
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
"""A basic class of metric collectors. It collects a specific
|
||||
metric during training or evaluation and it's always used with
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
collector works.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool):
|
||||
# is the metric only read for the full epoch
|
||||
self._epoch_only = epoch_only
|
||||
|
||||
@property
|
||||
def epoch_only(self):
|
||||
"""Returns :attr:`epoch_only`.
|
||||
"""
|
||||
return self._epoch_only
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Resets the metric to it's initial state.
|
||||
By default, this is called at the start of each epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""Updates the metric's state using the passed batch output.
|
||||
By default, this is called once for each batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_last_step_value(self):
|
||||
"""Returns the metric value in the last iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_accumulated_value(self):
|
||||
"""Computes the metric based on it's accumulated state.
|
||||
By default, this is called at the end of each epoch.
|
||||
|
||||
:return: the actual quantity of interest
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def is_better(a, b) -> bool:
|
||||
"""Compares a and b, and returns whether a is better than b
|
||||
|
||||
:return: The result of comparison
|
||||
:rtype: bool
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Loss(Metric):
|
||||
"""A metric collector for loss.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
|
||||
"""
|
||||
self.last_step_loss.zero_()
|
||||
self.accum_loss.zero_()
|
||||
self.count = 0
|
||||
|
||||
def update(self, loss) -> None:
|
||||
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
|
||||
It expects the output has loss.
|
||||
|
||||
:param loss: Current loss of the output
|
||||
"""
|
||||
# expect output to be logits, label and loss
|
||||
loss_ = loss.detach()
|
||||
self.last_step_loss.copy_(loss_)
|
||||
self.accum_loss.add_(loss_)
|
||||
self.count += 1
|
||||
|
||||
def get_accumulated_value(self):
|
||||
"""Returns accumulated loss.
|
||||
"""
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
|
||||
|
||||
self.accum_loss.div_(self.count)
|
||||
return self.accum_loss.item()
|
||||
|
||||
def get_last_step_value(self):
|
||||
"""Returns :attr:`last_step_loss`.
|
||||
"""
|
||||
return self.last_step_loss
|
||||
|
||||
def is_better(a, b):
|
||||
return a < b
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
self.last_step_correct.zero_()
|
||||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
:param logits: The logits output of the model
|
||||
:param label: The labels of the input data
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def get_last_step_value(self):
|
||||
dist.all_reduce(self.last_step_sum,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
dist.all_reduce(self.last_step_correct,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
return (self.last_step_sum / self.last_step_correct).item()
|
||||
|
||||
def get_accumulated_value(self):
|
||||
dist.all_reduce(self.accumulated_sum,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
dist.all_reduce(self.accumulated_correct,
|
||||
group=gpc.get_group(ParallelMode.DATA))
|
||||
return (self.accumulated_correct / self.accumulated_sum).item()
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
|
||||
class Accuracy2D(Accuracy):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks. This class is the same as :class:`Accuracy` but used in 2D
|
||||
model parallelism.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2D_ROW,
|
||||
1
|
||||
)
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2D_COL,
|
||||
0,
|
||||
)
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
|
||||
class Accuracy2p5D(Accuracy):
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
|
||||
def update(self, logits, label) -> None:
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(label, (list, tuple)):
|
||||
label = label[0]
|
||||
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2P5D_ROW,
|
||||
1
|
||||
)
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2P5D_COL,
|
||||
0,
|
||||
)
|
||||
logits = _gather(
|
||||
logits,
|
||||
ParallelMode.PARALLEL_2P5D_DEP,
|
||||
0,
|
||||
)
|
||||
# update
|
||||
preds = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(label == preds)
|
||||
self.last_step_sum.fill_(label.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
|
||||
class Accuracy3D(Accuracy):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks. This class is the same as :class:`Accuracy` but used in 3D
|
||||
model parallelism.
|
||||
|
||||
:param input_parallel_mode: The parallel mode of the input, generally it should be `ParallelMode.PARALLEL_3D_OUTPUT`
|
||||
:type input_parallel_mode: `ParallelMode`
|
||||
:param weight_parallel_mode: The parallel mode of the weight, generally it should be `ParallelMode.PARALLEL_3D_WEIGHT`
|
||||
:type weight_parallel_mode: `ParallelMode`
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
|
||||
def __init__(self, epoch_only, input_parallel_mode, weight_parallel_mode):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.depth = int(os.environ['DEPTH_3D'])
|
||||
self.input_parallel_mode = input_parallel_mode
|
||||
self.weight_parallel_mode = weight_parallel_mode
|
||||
self.output_parallel_mode = get_last_group(input_parallel_mode,
|
||||
weight_parallel_mode)
|
||||
|
||||
def update(self, logits, target):
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(target, (list, tuple)):
|
||||
target = target[0]
|
||||
|
||||
batch_size = target.size(0)
|
||||
|
||||
j = gpc.get_local_rank(self.input_parallel_mode)
|
||||
i = gpc.get_local_rank(self.weight_parallel_mode)
|
||||
target = torch.chunk(target, self.depth, dim=0)[i]
|
||||
target = torch.chunk(target, self.depth, dim=0)[j]
|
||||
|
||||
logits = all_gather(logits, -1, self.output_parallel_mode)
|
||||
prediction = torch.argmax(logits, dim=-1)
|
||||
correct = torch.sum(prediction == target)
|
||||
|
||||
dist.all_reduce(correct, group=gpc.get_group(self.input_parallel_mode))
|
||||
dist.all_reduce(correct,
|
||||
group=gpc.get_group(self.weight_parallel_mode))
|
||||
|
||||
self.last_step_sum.fill_(batch_size)
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
Reference in New Issue
Block a user