mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
Layer integration (#83)
* integrated parallel layers for ease of building models * integrated 2.5d layers * cleaned codes and unit tests * added log metric by step hook; updated imagenet benchmark; fixed some bugs * reworked initialization; cleaned codes Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
This commit is contained in:
@@ -1,15 +1,12 @@
|
||||
from ._base_hook import BaseHook
|
||||
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
|
||||
from ._metric_hook import (LossHook, Accuracy2DHook, AccuracyHook, MetricHook,
|
||||
Accuracy1DHook, Accuracy2p5DHook, Accuracy3DHook)
|
||||
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
|
||||
from ._checkpoint_hook import LoadCheckpointHook, SaveCheckpointHook
|
||||
from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook,
|
||||
TensorboardHook)
|
||||
from ._lr_scheduler_hook import LRSchedulerHook
|
||||
from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook
|
||||
|
||||
__all__ = [
|
||||
'BaseHook', 'MetricHook',
|
||||
'LoadCheckpointHook', 'SaveCheckpointHook',
|
||||
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
|
||||
'Accuracy1DHook', 'Accuracy2p5DHook', 'Accuracy3DHook',
|
||||
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
|
||||
'LRSchedulerHook'
|
||||
'BaseHook', 'MetricHook', 'LoadCheckpointHook', 'SaveCheckpointHook', 'LossHook', 'AccuracyHook',
|
||||
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook', 'LRSchedulerHook',
|
||||
'ThroughputHook', 'LogMetricByStepHook'
|
||||
]
|
||||
|
@@ -16,11 +16,11 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
|
||||
from ._base_hook import BaseHook
|
||||
|
||||
|
||||
def _format_number(val):
|
||||
def _format_number(val, prec=5):
|
||||
if isinstance(val, float):
|
||||
return f'{val:.5g}'
|
||||
return f'{val:.{prec}g}'
|
||||
elif torch.is_tensor(val) and torch.is_floating_point(val):
|
||||
return f'{val.item():.5g}'
|
||||
return f'{val.item():.{prec}g}'
|
||||
return val
|
||||
|
||||
|
||||
@@ -37,6 +37,24 @@ class LogByEpochHook(BaseHook):
|
||||
return trainer.cur_epoch % self._interval == 0
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByStepHook(BaseHook):
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
|
||||
trainer.states['step_metrics'][metric_name.lower()] = \
|
||||
f'{_format_number(metric_calculator.get_last_step_value())}'
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
|
||||
trainer.states['step_metrics'][metric_name.lower()] = \
|
||||
f'{_format_number(metric_calculator.get_last_step_value())}'
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class LogMetricByEpochHook(LogByEpochHook):
|
||||
"""Specialized Hook to record the metric to log.
|
||||
@@ -61,7 +79,7 @@ class LogMetricByEpochHook(LogByEpochHook):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
msg.append(
|
||||
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
|
||||
msg = ', '.join(msg)
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
@@ -69,15 +87,15 @@ class LogMetricByEpochHook(LogByEpochHook):
|
||||
msg = self._get_str(trainer=trainer, mode='train')
|
||||
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(
|
||||
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}')
|
||||
# f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
if self._is_epoch_to_log(trainer):
|
||||
msg = self._get_str(trainer=trainer, mode='test')
|
||||
if self._is_rank_to_log:
|
||||
self.logger.info(
|
||||
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
|
||||
# f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@@ -131,8 +149,7 @@ class TensorboardHook(BaseHook):
|
||||
log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
self.writer = SummaryWriter(
|
||||
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||
self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_rank_{rank}')
|
||||
|
||||
def _log_by_iter(self, trainer, mode: str):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
@@ -141,16 +158,14 @@ class TensorboardHook(BaseHook):
|
||||
val = metric_calculator.get_last_step_value()
|
||||
|
||||
if self._is_valid_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val,
|
||||
trainer.cur_step)
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
|
||||
|
||||
def _log_by_epoch(self, trainer, mode: str):
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
if metric_calculator.epoch_only:
|
||||
val = metric_calculator.get_accumulated_value()
|
||||
if self._is_valid_rank_to_log:
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val,
|
||||
trainer.cur_step)
|
||||
self.writer.add_scalar(f'{metric_name}/{mode}', val, trainer.cur_step)
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
self._log_by_iter(trainer, mode='test')
|
||||
@@ -178,15 +193,13 @@ class LogTimingByEpochHook(LogByEpochHook):
|
||||
:param log_eval: Whether writes in evaluation
|
||||
:type log_eval: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
timer: MultiTimer,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
ignore_num_train_steps: int = 0
|
||||
) -> None:
|
||||
ignore_num_train_steps: int = 0) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._timer = timer
|
||||
self._log_eval = log_eval
|
||||
@@ -197,40 +210,39 @@ class LogTimingByEpochHook(LogByEpochHook):
|
||||
self._ignore_num_train_steps = ignore_num_train_steps
|
||||
self._is_train_step_history_trimmed = False
|
||||
|
||||
def _get_message(self):
|
||||
def _get_message(self, mode):
|
||||
msg = []
|
||||
for timer_name, timer in self._timer:
|
||||
last_elapsed_time = timer.get_elapsed_time()
|
||||
if timer.has_history:
|
||||
if timer_name == 'train-step' and not self._is_train_step_history_trimmed:
|
||||
timer._history = timer._history[self._ignore_num_train_steps:]
|
||||
self._is_train_step_history_trimmed = True
|
||||
history_mean = timer.get_history_mean()
|
||||
history_sum = timer.get_history_sum()
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s')
|
||||
else:
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
|
||||
if timer_name.startswith(mode):
|
||||
last_elapsed_time = timer.get_elapsed_time()
|
||||
if timer.has_history:
|
||||
if timer_name == 'Train-step' and not self._is_train_step_history_trimmed:
|
||||
timer._history = timer._history[self._ignore_num_train_steps:]
|
||||
self._is_train_step_history_trimmed = True
|
||||
history_mean = timer.get_history_mean()
|
||||
history_sum = timer.get_history_sum()
|
||||
msg.append(
|
||||
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s'
|
||||
)
|
||||
else:
|
||||
msg.append(f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
|
||||
|
||||
msg = ', '.join(msg)
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
msg = self._get_message()
|
||||
self.logger.info(
|
||||
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={trainer.steps_per_epoch}')
|
||||
msg = self._get_message('Train')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Train]: {msg}, #steps/epoch = {trainer.steps_per_epoch}')
|
||||
|
||||
def after_test_epoch(self, trainer):
|
||||
"""Writes log after finishing a testing epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
|
||||
msg = self._get_message()
|
||||
self.logger.info(
|
||||
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
|
||||
msg = self._get_message('Test')
|
||||
self.logger.info(f'[Epoch {trainer.cur_epoch} / Test]: {msg}')
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@@ -246,14 +258,12 @@ class LogMemoryByEpochHook(LogByEpochHook):
|
||||
:param log_eval: Whether writes in evaluation
|
||||
:type log_eval: bool, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
report_cpu: bool = False
|
||||
) -> None:
|
||||
report_cpu: bool = False) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
|
||||
@@ -262,20 +272,16 @@ class LogMemoryByEpochHook(LogByEpochHook):
|
||||
"""Resets before training.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
report_memory_usage('before-train', self.logger)
|
||||
report_memory_usage('Before-train', self.logger)
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
"""Writes log after finishing a training epoch.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
|
||||
report_memory_usage(
|
||||
f'After Train - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
|
||||
self.logger)
|
||||
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Train]', self.logger)
|
||||
|
||||
def after_test(self, trainer):
|
||||
"""Reports after testing.
|
||||
"""
|
||||
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
|
||||
report_memory_usage(
|
||||
f'After Test - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
|
||||
self.logger)
|
||||
report_memory_usage(f'[Epoch {trainer.cur_epoch} / Test]', self.logger)
|
||||
|
@@ -1,9 +1,7 @@
|
||||
from colossalai.registry import HOOKS
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.builder import build_lr_scheduler
|
||||
from colossalai.registry import HOOKS
|
||||
from ._metric_hook import MetricHook
|
||||
from ..metric import LearningRate
|
||||
from ._metric_hook import LearningRateMetric, MetricHook
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@@ -19,28 +17,28 @@ class LRSchedulerHook(MetricHook):
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lr_scheduler,
|
||||
by_epoch: bool,
|
||||
store_lr_in_state: bool = True,
|
||||
priority: int = 1,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
lr_scheduler,
|
||||
by_epoch: bool,
|
||||
store_lr_in_state: bool = True,
|
||||
priority: int = 1,
|
||||
):
|
||||
super().__init__(priority=priority)
|
||||
self.by_epoch = by_epoch
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.store_lr_in_state = store_lr_in_state
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=self.by_epoch,
|
||||
initial_lr=self.lr_scheduler.get_last_lr()[0])
|
||||
trainer.states['metrics']['train']['LR'] = LearningRateMetric(epoch_only=self.by_epoch,
|
||||
initial_lr=self.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
def after_train_epoch(self, trainer):
|
||||
if self.by_epoch:
|
||||
self.lr_scheduler.step()
|
||||
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
|
||||
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
||||
if not self.by_epoch:
|
||||
self.lr_scheduler.step()
|
||||
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
trainer.states['metrics']['train']['LR'].update(self.lr_scheduler.get_last_lr()[0])
|
||||
|
@@ -1,11 +1,209 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication import all_reduce
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.registry import HOOKS
|
||||
from colossalai.utils import is_no_pp_or_last_stage
|
||||
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ..metric import Loss, Accuracy1D, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
"""A basic class of metric collectors. It collects a specific
|
||||
metric during training or evaluation and it's always used with
|
||||
:class:`MetricHook` to help it update its states and show the
|
||||
metric. So please use corresponding hook class to make the metric
|
||||
collector works.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool):
|
||||
# is the metric only read for the full epoch
|
||||
self._epoch_only = epoch_only
|
||||
|
||||
@property
|
||||
def epoch_only(self):
|
||||
"""Returns :attr:`epoch_only`.
|
||||
"""
|
||||
return self._epoch_only
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Resets the metric to it's initial state.
|
||||
By default, this is called at the start of each epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs) -> None:
|
||||
"""Updates the metric's state using the passed batch output.
|
||||
By default, this is called once for each batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_last_step_value(self):
|
||||
"""Returns the metric value in the last iteration.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_accumulated_value(self):
|
||||
"""Computes the metric based on it's accumulated state.
|
||||
By default, this is called at the end of each epoch.
|
||||
|
||||
:return: the actual quantity of interest
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def is_better(a, b) -> bool:
|
||||
"""Compares a and b, and returns whether a is better than b
|
||||
|
||||
:return: The result of comparison
|
||||
:rtype: bool
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LossMetric(Metric):
|
||||
"""A metric collector for loss.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Sets :attr:`last_step_loss` and :attr:`accum_loss` to zero.
|
||||
"""
|
||||
self.last_step_loss.zero_()
|
||||
self.accum_loss.zero_()
|
||||
self.count = 0
|
||||
|
||||
def update(self, loss) -> None:
|
||||
"""Updates :attr:`last_step_loss` and :attr:`accum_loss` with current loss.
|
||||
It expects the output has loss.
|
||||
|
||||
:param loss: Current loss of the output
|
||||
"""
|
||||
# expect output to be logits, label and loss
|
||||
loss_ = loss.detach()
|
||||
self.last_step_loss.copy_(loss_)
|
||||
self.accum_loss.add_(loss_)
|
||||
self.count += 1
|
||||
|
||||
def get_accumulated_value(self):
|
||||
"""Returns accumulated loss.
|
||||
"""
|
||||
if gpc.is_initialized(ParallelMode.DATA):
|
||||
dist.all_reduce(self.accum_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
|
||||
self.accum_loss.div_(gpc.get_world_size(ParallelMode.DATA))
|
||||
|
||||
self.accum_loss.div_(self.count)
|
||||
return self.accum_loss.item()
|
||||
|
||||
def get_last_step_value(self):
|
||||
"""Returns :attr:`last_step_loss`.
|
||||
"""
|
||||
return self.last_step_loss
|
||||
|
||||
def is_better(a, b):
|
||||
return a < b
|
||||
|
||||
|
||||
class LearningRateMetric(Metric):
|
||||
"""A metric collector for learning rate.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool, initial_lr: float = 0.):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.lr = initial_lr
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def update(self, lr) -> None:
|
||||
self.lr = lr
|
||||
|
||||
def get_last_step_value(self):
|
||||
return self.lr
|
||||
|
||||
def get_accumulated_value(self):
|
||||
return self.lr
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class AccuracyMetric(Metric):
|
||||
"""A metric collector for accuracy. It only works for classification
|
||||
tasks.
|
||||
|
||||
:param epoch_only: Whether the metric only read for the full epoch
|
||||
:type epoch_only: bool
|
||||
"""
|
||||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.acc = accuracy_func
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
self.last_step_correct.zero_()
|
||||
self.accumulated_sum.zero_()
|
||||
self.accumulated_correct.zero_()
|
||||
|
||||
def update(self, logits, targets) -> None:
|
||||
"""Updates last step accuracy and accumulated accuracy with current logits
|
||||
and labels. It expects the output has logits and labels.
|
||||
|
||||
:param logits: The logits output of the model
|
||||
:param label: The labels of the input data
|
||||
"""
|
||||
if isinstance(logits, (list, tuple)):
|
||||
logits = logits[0]
|
||||
if isinstance(targets, (list, tuple)):
|
||||
targets = targets[0]
|
||||
# update
|
||||
correct = self.acc(logits, targets)
|
||||
|
||||
self.last_step_sum.fill_(targets.size(0))
|
||||
self.last_step_correct.fill_(correct)
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def get_last_step_value(self):
|
||||
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
|
||||
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
|
||||
return (self.last_step_correct / self.last_step_sum).item()
|
||||
|
||||
def get_accumulated_value(self):
|
||||
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
|
||||
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
|
||||
return (self.accumulated_correct / self.accumulated_sum).item()
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
return a > b
|
||||
|
||||
|
||||
class MetricHook(BaseHook):
|
||||
@@ -19,10 +217,10 @@ class MetricHook(BaseHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
priority: int,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
priority: int,
|
||||
):
|
||||
super().__init__(priority)
|
||||
self._is_stage_to_compute = is_no_pp_or_last_stage()
|
||||
|
||||
@@ -40,7 +238,6 @@ class LossHook(MetricHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
@@ -48,14 +245,12 @@ class LossHook(MetricHook):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
|
||||
if self._is_stage_to_compute:
|
||||
self.train_loss = Loss(epoch_only=False)
|
||||
self.test_loss = Loss(epoch_only=True)
|
||||
self.train_loss = LossMetric(epoch_only=False)
|
||||
self.test_loss = LossMetric(epoch_only=True)
|
||||
|
||||
# register the metric calculator
|
||||
trainer.states['metrics']['train'][
|
||||
self.train_loss.__class__.__name__] = self.train_loss
|
||||
trainer.states['metrics']['test'][
|
||||
self.test_loss.__class__.__name__] = self.test_loss
|
||||
trainer.states['metrics']['train']['Loss'] = self.train_loss
|
||||
trainer.states['metrics']['test']['Loss'] = self.test_loss
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
@@ -74,124 +269,6 @@ class LossHook(MetricHook):
|
||||
self.test_loss.update(loss)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy1DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy1D`.
|
||||
It acts the same as :class:`AccuracyHook`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy1D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy2DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy2D`.
|
||||
It acts the same as :class:`AccuracyHook`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int, optional
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy2D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy2p5DHook(MetricHook):
|
||||
def __init__(self, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy2p5D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class Accuracy3DHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy3D`.
|
||||
|
||||
:param trainer: Trainer attached with current hook
|
||||
:param priority: Priority in the printing, hooks with small priority will be printed in front
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy3D(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class AccuracyHook(MetricHook):
|
||||
"""Specialized hook class for :class:`Accuracy`.
|
||||
@@ -201,22 +278,87 @@ class AccuracyHook(MetricHook):
|
||||
:type trainer: Trainer
|
||||
:type priority: int
|
||||
"""
|
||||
|
||||
def __init__(self, priority: int = 0):
|
||||
def __init__(self, accuracy_func: Callable, priority: int = 0):
|
||||
super().__init__(priority)
|
||||
self.accuracy_func = accuracy_func
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = Accuracy(epoch_only=True)
|
||||
self.metric = AccuracyMetric(epoch_only=True, accuracy_func=self.accuracy_func)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['test'][
|
||||
self.metric.__class__.__name__] = self.metric
|
||||
trainer.states['metrics']['test']['Accuracy'] = self.metric
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, label, *args):
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(logits, label)
|
||||
self.metric.update(logits, targets)
|
||||
|
||||
|
||||
class ThroughputMetric(Metric):
|
||||
def __init__(self, epoch_only: bool):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.accumulated_num_samples.zero_()
|
||||
self.accumulated_used_time.zero_()
|
||||
self.last_step_num_samples.zero_()
|
||||
self.last_step_used_time.zero_()
|
||||
|
||||
def update(self, tensor, time) -> None:
|
||||
if isinstance(tensor, (list, tuple)):
|
||||
tensor = tensor[0]
|
||||
self.last_step_num_samples.fill_(tensor.size(0))
|
||||
self.last_step_used_time.fill_(time)
|
||||
self.accumulated_num_samples += self.last_step_num_samples
|
||||
self.accumulated_used_time += self.last_step_used_time
|
||||
|
||||
def get_last_step_value(self):
|
||||
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
|
||||
return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
|
||||
|
||||
def get_accumulated_value(self):
|
||||
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
|
||||
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
|
||||
|
||||
def is_better(a, b) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
class ThroughputHook(MetricHook):
|
||||
def __init__(self, priority: int = 10):
|
||||
super().__init__(priority)
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = ThroughputMetric(epoch_only=True)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['train']['Throughput'] = self.metric
|
||||
trainer.states['metrics']['test']['Throughput'] = self.metric
|
||||
|
||||
def before_train_epoch(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_train_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
self.metric.reset()
|
||||
|
||||
def after_test_iter(self, trainer, logits, targets, *args):
|
||||
self.metric.update(targets, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
Reference in New Issue
Block a user