Develop/experiments (#59)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapt torch amp to work with tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, along with some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

* Split conv2d, class token, and positional embedding in 2D; fix random number handling in DDP;
fix convergence on CIFAR-10 and ImageNet-1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapt torch amp to work with tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, along with some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapt torch amp to work with tensor parallel (#18)

* fixed compatibility bugs between torch amp and tensor parallel, along with some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b7699.

* improved consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline

Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* optimized 3D layers to fix slow computation; tested ImageNet performance with 3D; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified APIs of 3D layers (#51)

* Update 2.5D layer code to achieve similar accuracy on the ImageNet-1k dataset

* update API for better usability (#58)

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Authored by Frank Lee on 2021-12-09 15:08:29 +08:00, committed by GitHub
parent eb2f8b1f6b · commit da01c234e1
229 changed files with 6532 additions and 8741 deletions

View File

@@ -5,9 +5,6 @@ from abc import ABC
from torch import Tensor
from colossalai.logging import get_global_dist_logger
from .._trainer import Trainer
class BaseHook(ABC):
"""This class allows users to add desired actions in specific time points
@@ -18,27 +15,31 @@ class BaseHook(ABC):
:type trainer: Trainer
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int) -> None:
self.trainer = trainer
self.priority = priority
self.logger = get_global_dist_logger()
def before_train(self):
def __init__(self, priority: int) -> None:
self.priority = priority
def after_hook_is_attached(self, trainer):
"""Actions after hooks are attached to trainer.
"""
pass
def before_train(self, trainer):
"""Actions before training.
"""
pass
def after_train(self):
def after_train(self, trainer):
"""Actions after training.
"""
pass
def before_train_iter(self):
def before_train_iter(self, trainer):
"""Actions before running a training iteration.
"""
pass
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
:param output: Output of the model
@@ -50,42 +51,42 @@ class BaseHook(ABC):
"""
pass
def before_train_epoch(self):
def before_train_epoch(self, trainer):
"""Actions before starting a training epoch.
"""
pass
def after_train_epoch(self):
def after_train_epoch(self, trainer):
"""Actions after finishing a training epoch.
"""
pass
def before_test(self):
def before_test(self, trainer):
"""Actions before evaluation.
"""
pass
def after_test(self):
def after_test(self, trainer):
"""Actions after evaluation.
"""
pass
def before_test_epoch(self):
def before_test_epoch(self, trainer):
"""Actions before starting a testing epoch.
"""
pass
def after_test_epoch(self):
def after_test_epoch(self, trainer):
"""Actions after finishing a testing epoch.
"""
pass
def before_test_iter(self):
def before_test_iter(self, trainer):
"""Actions before running a testing iteration.
"""
pass
def after_test_iter(self, output: Tensor, label: Tensor, loss: Tensor):
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
:param output: Output of the model
@@ -97,11 +98,11 @@ class BaseHook(ABC):
"""
pass
def init_runner_states(self, key, val):
def init_runner_states(self, trainer, key, val):
"""Initializes trainer's state.
:param key: Key of resetting state
:param val: Value of resetting state
"""
if key not in self.trainer.states:
self.trainer.states[key] = val
if key not in trainer.states:
trainer.states[key] = val
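The net effect of this hunk: a hook is no longer bound to a `Trainer` at construction time; `__init__` takes only a `priority`, and the trainer is passed into every callback, including the new `after_hook_is_attached`. Below is a minimal sketch of a custom hook written against the new signatures; the `IterationCounterHook` name and its counter state are illustrative only, not part of this commit.

```python
from torch import Tensor

from colossalai.trainer.hooks import BaseHook  # export path assumed


class IterationCounterHook(BaseHook):
    """Illustrative hook: counts training iterations per epoch."""

    def __init__(self, priority: int = 10) -> None:
        # the trainer is no longer stored on the hook
        super().__init__(priority=priority)
        self._iter_count = 0

    def after_hook_is_attached(self, trainer):
        # runs once the hook has been registered with a trainer
        self.init_runner_states(trainer, 'iters_per_epoch', 0)

    def before_train_epoch(self, trainer):
        self._iter_count = 0

    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        self._iter_count += 1

    def after_train_epoch(self, trainer):
        trainer.states['iters_per_epoch'] = self._iter_count
```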

View File

@@ -2,9 +2,9 @@
# -*- encoding: utf-8 -*-
import os.path as osp
from colossalai.logging import get_dist_logger
from colossalai.registry import HOOKS
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import BaseHook
from colossalai.utils import is_dp_rank_0
from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
@@ -16,12 +16,10 @@ from ._lr_scheduler_hook import LRSchedulerHook
class SaveCheckpointHook(BaseHook):
"""Saves the model by interval in training process.
:param trainer: Trainer attached with current hook
:param interval: Saving interval
:param checkpoint_dir: Directory of saving checkpoint
:param suffix: Saving suffix of the file
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type interval: int, optional
:type checkpoint_dir: int, optional
:type suffix: str, optional
@@ -29,59 +27,55 @@ class SaveCheckpointHook(BaseHook):
"""
def __init__(self,
trainer: Trainer,
interval: int = 1,
checkpoint_dir: str = None,
suffix: str = '',
priority: int = 10):
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
super().__init__(priority=priority)
self.interval = interval
self.checkpoint_dir = checkpoint_dir
self.suffix = suffix
self.logger = get_dist_logger()
# get lr scheduler from the LRSchedulerHook before train
self._lr_scheduler = None
def before_train(self):
def after_hook_is_attached(self, trainer):
# check if lr scheduler is present in LRSchedulerHook
for hook in self.trainer.hooks:
for hook in trainer.hooks:
if isinstance(hook, LRSchedulerHook):
self._lr_scheduler = hook.lr_scheduler
break
def after_train_epoch(self):
def after_train_epoch(self, trainer):
"""Saves the model after a training epoch.
"""
# save by interval
if self.trainer.cur_epoch % self.interval == 0:
if trainer.cur_epoch % self.interval == 0:
# only gpus with data parallel rank equals to 0 write to the disk
if is_dp_rank_0():
save_path = get_checkpoint_path(self.checkpoint_dir,
self.trainer.cur_epoch,
trainer.cur_epoch,
suffix=self.suffix)
save_checkpoint(save_path,
self.trainer.cur_epoch,
self.trainer.engine.model,
self.trainer.engine.optimizer,
trainer.cur_epoch,
trainer.engine.model,
trainer.engine.optimizer,
self._lr_scheduler)
self.logger.info(
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
f'checkpoint for epoch {trainer.cur_epoch} is saved to {self.checkpoint_dir}', ranks=[0])
@HOOKS.register_module
class LoadCheckpointHook(BaseHook):
"""Loads the model before training process.
:param trainer: Trainer attached with current hook
:param checkpoint_dir: Directory of saving checkpoint
:param epoch: Epoch number to be set
:param finetune: Whether allows to load a part of the model
:param strict: Whether loads a model that has the same shape of parameters
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type checkpoint_dir: str, optional
:type epoch: str, optional
:type finetune: bool, optional
@@ -90,28 +84,26 @@ class LoadCheckpointHook(BaseHook):
"""
def __init__(self,
trainer: Trainer = None,
checkpoint_dir: str = None,
epoch: int = -1,
finetune: bool = False,
strict: bool = False,
suffix: str = '',
priority: int = 0) -> None:
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
super().__init__(priority=priority)
self.epoch = epoch
self.checkpoint_dir = checkpoint_dir
self.finetune = finetune
self.suffix = suffix
self.strict = strict
self.logger = get_dist_logger()
def before_train(self):
def before_train(self, trainer):
"""Loads parameters to the model before training.
"""
# check if lr scheduler is present in LRSchedulerHook
lr_scheduler = None
for hook in self.trainer.hooks:
for hook in trainer.hooks:
if isinstance(hook, LRSchedulerHook):
lr_scheduler = hook.lr_scheduler
break
@@ -124,17 +116,17 @@ class LoadCheckpointHook(BaseHook):
if osp.exists(path):
last_epoch, _ = load_checkpoint(path,
self.trainer.engine.model,
self.trainer.engine.optimizer,
trainer.engine.model,
trainer.engine.optimizer,
lr_scheduler,
finetune=self.finetune,
strict=self.strict)
if self.finetune:
self.trainer.cur_epoch = 0
trainer.cur_epoch = 0
else:
self.trainer.cur_epoch = last_epoch
trainer.cur_epoch = last_epoch
self.logger.info(
f'loaded checkpoint from {path}')
f'loaded checkpoint from {path}', ranks=[0])
else:
raise FileNotFoundError(f'checkpoint is not found at {path}')
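With the `trainer` argument and its `isinstance` assertions gone, both checkpoint hooks are now constructed from their own options only. A sketch of the new construction, assuming the hooks are exported from `colossalai.trainer.hooks`, with the checkpoint directory chosen purely for illustration:

```python
from colossalai.trainer.hooks import LoadCheckpointHook, SaveCheckpointHook

hooks = [
    # load the latest (epoch=-1) checkpoint before training starts
    LoadCheckpointHook(checkpoint_dir='./checkpoints', epoch=-1, strict=False, priority=0),
    # save every epoch; the lr scheduler is picked up from LRSchedulerHook
    # in after_hook_is_attached, so no scheduler argument is needed here
    SaveCheckpointHook(interval=1, checkpoint_dir='./checkpoints', suffix='vit', priority=10),
]
# the hook list is handed to the trainer when training starts
# (the exact registration call is outside this diff)
```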

View File

@@ -6,35 +6,40 @@ import os.path as osp
import torch
from torch.utils.tensorboard import SummaryWriter
from typing import List
from decimal import Decimal
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.trainer._trainer import Trainer
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage
from colossalai.logging import DistributedLogger
from colossalai.utils import report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
from ._base_hook import BaseHook
def _format_number(val):
if isinstance(val, float):
return f'{val:.5f}'
elif torch.is_floating_point(val):
return f'{val.item():.5f}'
return f'{val:.5g}'
elif torch.is_tensor(val) and torch.is_floating_point(val):
return f'{val.item():.5g}'
return val
class EpochIntervalHook(BaseHook):
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
super().__init__(trainer, priority)
class LogByEpochHook(BaseHook):
def __init__(self,
logger,
interval: int = 1,
priority: int = 1):
super().__init__(priority)
self.logger = logger
self._interval = interval
def _is_epoch_to_log(self):
return self.trainer.cur_epoch % self._interval == 0
def _is_epoch_to_log(self, trainer):
return trainer.cur_epoch % self._interval == 0
@HOOKS.register_module
class LogMetricByEpochHook(EpochIntervalHook):
class LogMetricByEpochHook(LogByEpochHook):
"""Specialized Hook to record the metric to log.
:param trainer: Trainer attached with current hook
@@ -45,32 +50,35 @@ class LogMetricByEpochHook(EpochIntervalHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
def __init__(self,
logger,
interval: int = 1,
priority: int = 10) -> None:
super().__init__(logger, interval, priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
def _get_str(self, mode):
def _get_str(self, trainer, mode):
msg = []
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
msg.append(
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg = ', '.join(msg)
return msg
def after_train_epoch(self):
if self._is_epoch_to_log():
msg = self._get_str(mode='train')
def after_train_epoch(self, trainer):
if self._is_epoch_to_log(trainer):
msg = self._get_str(trainer=trainer, mode='train')
if self._is_rank_to_log:
self.logger.info(
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
def after_test_epoch(self):
if self._is_epoch_to_log():
msg = self._get_str(mode='test')
def after_test_epoch(self, trainer):
if self._is_epoch_to_log(trainer):
msg = self._get_str(trainer=trainer, mode='test')
if self._is_rank_to_log:
self.logger.info(
f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
@@ -86,74 +94,79 @@ class TensorboardHook(BaseHook):
"""
def __init__(self,
trainer: Trainer,
log_dir: str,
dp_rank_0_only: bool = True,
tp_rank_0_only: bool = True,
ranks: List = None,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
priority: int = 10,
) -> None:
super().__init__(trainer=trainer, priority=priority)
super().__init__(priority=priority)
# create log dir
if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
os.makedirs(log_dir, exist_ok=True)
# determine the ranks to generate tensorboard logs
self._is_valid_rank_to_log = is_no_pp_or_last_stage()
self._is_valid_rank_to_log = False
if not gpc.is_initialized(parallel_mode):
self._is_valid_rank_to_log = True
else:
local_rank = gpc.get_local_rank(parallel_mode)
if dp_rank_0_only:
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
if ranks is None or local_rank in ranks:
self._is_valid_rank_to_log = True
if tp_rank_0_only:
self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
# check for
if gpc.is_initialized(ParallelMode.PIPELINE) and \
not gpc.is_last_rank(ParallelMode.PIPELINE) and self._is_valid_rank_to_log:
raise ValueError("Tensorboard hook can only log on the last rank of pipeline process group")
if self._is_valid_rank_to_log:
# create workspace on only one rank
if gpc.is_initialized(ParallelMode.GLOBAL):
rank = gpc.get_global_rank()
if gpc.is_initialized(parallel_mode):
rank = gpc.get_local_rank(parallel_mode)
else:
rank = 0
# create workspace
log_dir = osp.join(log_dir, f'rank_{rank}')
log_dir = osp.join(log_dir, f'{parallel_mode}_rank_{rank}')
os.makedirs(log_dir, exist_ok=True)
self.writer = SummaryWriter(
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
def _log_by_iter(self, mode: str):
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
def _log_by_iter(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
self.trainer.cur_step)
trainer.cur_step)
def _log_by_epoch(self, mode: str):
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
def _log_by_epoch(self, trainer, mode: str):
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_valid_rank_to_log:
self.writer.add_scalar(f'{metric_name}/{mode}', val,
self.trainer.cur_step)
trainer.cur_step)
def after_test_iter(self, *args):
self._log_by_iter(mode='test')
def after_test_iter(self, trainer, *args):
self._log_by_iter(trainer, mode='test')
def after_test_epoch(self):
self._log_by_epoch(mode='test')
def after_test_epoch(self, trainer):
self._log_by_epoch(trainer, mode='test')
def after_train_iter(self, *args):
self._log_by_iter(mode='train')
def after_train_iter(self, trainer, *args):
self._log_by_iter(trainer, mode='train')
def after_train_epoch(self):
self._log_by_epoch(mode='train')
def after_train_epoch(self, trainer):
self._log_by_epoch(trainer, mode='train')
@HOOKS.register_module
class LogTimingByEpochHook(EpochIntervalHook):
class LogTimingByEpochHook(LogByEpochHook):
"""Specialized Hook to write timing record to log.
:param trainer: Trainer attached with current hook
@@ -167,53 +180,61 @@ class LogTimingByEpochHook(EpochIntervalHook):
"""
def __init__(self,
trainer: Trainer,
timer: MultiTimer,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True
log_eval: bool = True,
ignore_num_train_steps: int = 0
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
set_global_multitimer_status(True)
self._global_timer = get_global_multitimer()
super().__init__(logger=logger, interval=interval, priority=priority)
self._timer = timer
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
# extra handling to avoid the unstable readings of the first
# few training steps to affect the history mean time
self._ignore_num_train_steps = ignore_num_train_steps
self._is_train_step_history_trimmed = False
def _get_message(self):
msg = []
for timer_name, timer in self._global_timer:
for timer_name, timer in self._timer:
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
if timer_name == 'train-step' and not self._is_train_step_history_trimmed:
timer._history = timer._history[self._ignore_num_train_steps:]
self._is_train_step_history_trimmed = True
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
f'{timer_name}: last elapsed time = {last_elapsed_time}, '
f'history sum = {history_sum}, history mean = {history_mean}')
f'{timer_name}: last = {_format_number(last_elapsed_time)} s, mean = {_format_number(history_mean)} s')
else:
msg.append(
f'{timer_name}: last elapsed time = {last_elapsed_time}')
f'{timer_name}: last = {_format_number(last_elapsed_time)} s')
msg = ', '.join(msg)
return msg
def after_train_epoch(self):
def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log() and self._is_rank_to_log:
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
msg = self._get_message()
self.logger.info(
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
f'Training - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}, num steps per epoch={trainer.steps_per_epoch}')
def after_test_epoch(self):
def after_test_epoch(self, trainer):
"""Writes log after finishing a testing epoch.
"""
if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
msg = self._get_message()
self.logger.info(
f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
f'Testing - Epoch {trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
class LogMemoryByEpochHook(EpochIntervalHook):
class LogMemoryByEpochHook(LogByEpochHook):
"""Specialized Hook to write memory usage record to log.
:param trainer: Trainer attached with current hook
@@ -227,33 +248,34 @@ class LogMemoryByEpochHook(EpochIntervalHook):
"""
def __init__(self,
trainer: Trainer,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True
log_eval: bool = True,
report_cpu: bool = False
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
set_global_multitimer_status(True)
self._global_timer = get_global_multitimer()
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
def before_train(self):
def before_train(self, trainer):
"""Resets before training.
"""
if self._is_epoch_to_log() and self._is_rank_to_log:
report_memory_usage('before-train')
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage('before-train', self.logger)
def after_train_epoch(self):
def after_train_epoch(self, trainer):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log() and self._is_rank_to_log:
if self._is_epoch_to_log(trainer) and self._is_rank_to_log:
report_memory_usage(
f'After Train - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
f'After Train - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
self.logger)
def after_test(self):
def after_test(self, trainer):
"""Reports after testing.
"""
if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
if self._is_epoch_to_log(trainer) and self._is_rank_to_log and self._log_eval:
report_memory_usage(
f'After Test - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
f'After Test - Epoch {trainer.cur_epoch} - {self.__class__.__name__}',
self.logger)
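All logging hooks now take an explicit logger (and, for timing, an explicit `MultiTimer`) instead of relying on the global timer and per-hook logger creation, and `TensorboardHook` selects its logging ranks through a `parallel_mode`/`ranks` pair. A sketch of the new constructors, with the hook export path assumed:

```python
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger
from colossalai.utils import MultiTimer
from colossalai.trainer.hooks import (LogMemoryByEpochHook, LogMetricByEpochHook,
                                      LogTimingByEpochHook, TensorboardHook)

logger = get_dist_logger()
timer = MultiTimer()  # should be the same timer instance that times the training loop

hooks = [
    LogMetricByEpochHook(logger=logger, interval=1),
    # write tensorboard logs only on global rank 0
    TensorboardHook(log_dir='./tb_logs', ranks=[0], parallel_mode=ParallelMode.GLOBAL),
    # drop the first few warm-up steps from the history mean
    LogTimingByEpochHook(timer=timer, logger=logger, ignore_num_train_steps=5),
    LogMemoryByEpochHook(logger=logger),
]
```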

View File

@@ -3,7 +3,6 @@ from torch import Tensor
from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
from .._trainer import Trainer
from ..metric import LearningRate
@@ -22,37 +21,26 @@ class LRSchedulerHook(MetricHook):
"""
def __init__(self,
trainer: Trainer,
lr_scheduler_cfg: dict,
by_epoch: bool = True,
lr_scheduler,
by_epoch: bool,
store_lr_in_state: bool = True,
priority: int = 1,
):
super().__init__(trainer=trainer, priority=priority)
super().__init__(priority=priority)
self.by_epoch = by_epoch
self.lr_scheduler = lr_scheduler
self.store_lr_in_state = store_lr_in_state
if by_epoch:
total_steps = trainer.max_epochs
else:
total_steps = trainer.max_epochs * trainer.steps_per_epoch
if trainer.max_steps is not None:
total_steps = min(total_steps, trainer.max_steps)
def after_hook_is_attached(self, trainer):
trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=self.by_epoch,
initial_lr=self.lr_scheduler.get_last_lr()[0])
lr_scheduler_cfg['total_steps'] = total_steps
self.lr_scheduler = build_lr_scheduler(
lr_scheduler_cfg, trainer.engine.optimizer)
if store_lr_in_state:
self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
initial_lr=self.lr_scheduler.get_lr()[0])
def after_train_epoch(self):
def after_train_epoch(self, trainer):
if self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
if not self.by_epoch:
self.lr_scheduler.step()
self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_last_lr()[0])
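`LRSchedulerHook` now receives an already-built scheduler rather than a config dict, so deciding the schedule length (previously derived from `trainer.max_epochs` and `steps_per_epoch`) becomes the caller's responsibility, and the scheduler must wrap the same optimizer the engine trains with. A minimal sketch using a plain PyTorch scheduler for illustration, with the hook export path assumed:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from colossalai.trainer.hooks import LRSchedulerHook  # export path assumed

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# schedule length is chosen by the caller now, e.g. the number of training epochs
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# by_epoch=True steps the scheduler once per epoch; by_epoch=False steps it
# after every training iteration
lr_hook = LRSchedulerHook(lr_scheduler=scheduler, by_epoch=True)
```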

View File

@@ -5,8 +5,7 @@ from colossalai.context import ParallelMode
from colossalai.registry import HOOKS
from colossalai.utils import is_no_pp_or_last_stage
from ._base_hook import BaseHook
from .._trainer import Trainer
from ..metric import Loss, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
from ..metric import Loss, Accuracy1D, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
class MetricHook(BaseHook):
@@ -22,16 +21,14 @@ class MetricHook(BaseHook):
"""
def __init__(self,
trainer: Trainer,
priority: int,
):
super().__init__(trainer, priority)
super().__init__(priority)
self._is_stage_to_compute = is_no_pp_or_last_stage()
self._check_metric_states_initialization()
def _check_metric_states_initialization(self):
if 'metrics' not in self.trainer.states:
self.init_runner_states('metrics', dict(train={}, test={}))
def _check_metric_states_initialization(self, trainer):
if 'metrics' not in trainer.states:
self.init_runner_states(trainer, 'metrics', dict(train={}, test={}))
@HOOKS.register_module
@@ -44,36 +41,71 @@ class LossHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.train_loss = Loss(epoch_only=False)
self.test_loss = Loss(epoch_only=True)
# register the metric calculator
self.trainer.states['metrics']['train'][
trainer.states['metrics']['train'][
self.train_loss.__class__.__name__] = self.train_loss
self.trainer.states['metrics']['test'][
trainer.states['metrics']['test'][
self.test_loss.__class__.__name__] = self.test_loss
def before_train_epoch(self):
def before_train_epoch(self, trainer):
if self._is_stage_to_compute:
self.train_loss.reset()
def after_train_iter(self, logits, label, loss):
def after_train_iter(self, trainer, logits, label, loss):
if self._is_stage_to_compute:
self.train_loss.update(loss)
def before_test_epoch(self):
def before_test_epoch(self, trainer):
if self._is_stage_to_compute:
self.test_loss.reset()
def after_test_iter(self, logits, label, loss):
def after_test_iter(self, trainer, logits, label, loss):
if self._is_stage_to_compute:
self.test_loss.update(loss)
@HOOKS.register_module
class Accuracy1DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy1D`.
It acts the same as :class:`AccuracyHook`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy1D(epoch_only=True)
# register the metric
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy2D`.
@@ -85,42 +117,46 @@ class Accuracy2DHook(MetricHook):
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy2D(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2p5DHook(MetricHook):
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
self._check_metric_states_initialization(trainer)
if self._is_stage_to_compute:
self.metric = Accuracy2p5D(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@@ -136,26 +172,22 @@ class Accuracy3DHook(MetricHook):
"""
def __init__(self,
trainer: Trainer,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
priority: int = 10):
super().__init__(trainer, priority)
super().__init__(priority)
def after_hook_is_attached(self, trainer):
if self._is_stage_to_compute:
self.metric = Accuracy3D(epoch_only=True,
input_parallel_mode=input_parallel_mode,
weight_parallel_mode=weight_parallel_mode)
self.metric = Accuracy3D(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
@@ -170,20 +202,21 @@ class AccuracyHook(MetricHook):
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int = 0):
super().__init__(trainer, priority)
def __init__(self, priority: int = 0):
super().__init__(priority)
def after_hook_is_attached(self, trainer):
if self._is_stage_to_compute:
self.metric = Accuracy(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
def before_test(self, trainer):
if self._is_stage_to_compute:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
def after_test_iter(self, trainer, logits, label, *args):
if self._is_stage_to_compute:
self.metric.update(logits, label)
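With metric registration moved into `after_hook_is_attached`, every metric hook is now constructed with at most a priority, and the accuracy variant is chosen to match the tensor-parallel layout (the new `Accuracy1DHook` covers the 1D case). A sketch, export path assumed:

```python
from colossalai.trainer.hooks import (Accuracy1DHook, Accuracy2DHook, Accuracy2p5DHook,
                                      Accuracy3DHook, AccuracyHook, LossHook)

# LossHook tracks train and test loss; the accuracy hooks register a
# test-time metric.  Pick the one matching the tensor parallel mode in use,
# or plain AccuracyHook when no tensor parallelism is involved.
hooks = [
    LossHook(),
    Accuracy1DHook(),   # e.g. for 1D tensor parallelism
]

# Once attached, the metrics live in trainer.states['metrics']['train'|'test'],
# which is also where the logging hooks read them from.
```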