Migrated project

This commit is contained in:
zbian
2021-10-28 18:21:23 +02:00
parent 2ebaefc542
commit 404ecbdcc6
409 changed files with 35853 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
__all__ = [
'BaseHook', 'MetricHook',
'LoadCheckpointHook', 'SaveCheckpointHook',
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
]

View File

@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC
from torch import Tensor
from colossalai.logging import get_global_dist_logger
from .._trainer import Trainer
class BaseHook(ABC):
"""This class allows users to add desired actions in specific time points
during training or evaluation.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int) -> None:
self.trainer = trainer
self.priority = priority
self.logger = get_global_dist_logger()
def before_train(self):
"""Actions before training.
"""
pass
def after_train(self):
"""Actions after training.
"""
pass
def before_train_iter(self):
"""Actions before running a training iteration.
"""
pass
def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
:param output: Output of the model
:param label: Labels of the input data
:param loss: Loss between the output and input data
:type output: Tensor
:type label: Tensor
:type loss: Tensor
"""
pass
def before_train_epoch(self):
"""Actions before starting a training epoch.
"""
pass
def after_train_epoch(self):
"""Actions after finishing a training epoch.
"""
pass
def before_test(self):
"""Actions before evaluation.
"""
pass
def after_test(self):
"""Actions after evaluation.
"""
pass
def before_test_epoch(self):
"""Actions before starting a testing epoch.
"""
pass
def after_test_epoch(self):
"""Actions after finishing a testing epoch.
"""
pass
def before_test_iter(self):
"""Actions before running a testing iteration.
"""
pass
def after_test_iter(self, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
:param output: Output of the model
:param label: Labels of the input data
:param loss: Loss between the output and input data
:type output: Tensor
:type label: Tensor
:type loss: Tensor
"""
pass
def init_runner_states(self, key, val):
"""Initializes trainer's state.
:param key: Key of reseting state
:param val: Value of reseting state
"""
if key not in self.trainer.states:
self.trainer.states[key] = val

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp
import torch.distributed as dist
from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
from colossalai.registry import HOOKS
from colossalai.trainer.hooks import BaseHook
from colossalai.trainer import Trainer
from colossalai.utils import is_dp_rank_0
@HOOKS.register_module
class SaveCheckpointHook(BaseHook):
"""Saves the model by interval in training process.
:param trainer: Trainer attached with current hook
:param interval: Saving interval
:param checkpoint_dir: Directory of saving checkpoint
:param suffix: Saving suffix of the file
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type interval: int, optional
:type checkpoint_dir: int, optional
:type suffix: str, optional
:type priority: int, optional
"""
def __init__(self,
trainer: Trainer,
interval: int = 1,
checkpoint_dir: str = None,
suffix: str = '',
priority: int = 0):
super().__init__(trainer=trainer, priority=priority)
assert isinstance(trainer, Trainer), \
f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
self.interval = interval
self.checkpoint_dir = checkpoint_dir
self.suffix = suffix
def after_train_epoch(self):
"""Saves the model after a training epoch.
"""
# save by interval
if self.trainer.cur_epoch % self.interval == 0:
# only gpus with data parallel rank equals to 0 write to the disk
if is_dp_rank_0():
self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
self.logger.info(
f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
# wait until everyone is done
if dist.is_initialized():
dist.barrier()
@HOOKS.register_module
class LoadCheckpointHook(BaseHook):
"""Loads the model before training process.
:param trainer: Trainer attached with current hook
:param checkpoint_dir: Directory of saving checkpoint
:param epoch: Epoch number to be set
:param finetune: Whether allows to load a part of the model
:param strict: Whether loads a model that has the same shape of parameters
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type checkpoint_dir: str, optional
:type epoch: str, optional
:type finetune: bool, optional
:type strict: bool, optional
:type priority: int, optional
"""
def __init__(self,
trainer: Trainer = None,
checkpoint_dir: str = None,
epoch: int = -1,
finetune: bool = False,
strict: bool = False,
priority: int = 10) -> None:
assert isinstance(trainer, Trainer), \
f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
self.epoch = epoch
self.checkpoint_dir = checkpoint_dir
self.finetune = finetune
self.strict = strict
super().__init__(trainer=trainer, priority=priority)
def before_train(self):
"""Loads parameters to the model before training.
"""
if self.epoch == -1:
path = get_latest_checkpoint_path(self.checkpoint_dir)
else:
path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
if osp.exists(path):
self.trainer.load(
path, finetune=self.finetune, strict=self.strict)
self.logger.info(
f'loaded checkpoint from {path}')
else:
raise FileNotFoundError(f'checkpoint is not found at {path}')
# Some utilities want to load a checkpoint without distributed being initialized
if dist.is_initialized():
dist.barrier()

View File

@@ -0,0 +1,247 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import os.path as osp
import torch
from tensorboardX import SummaryWriter
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS
from colossalai.trainer._trainer import Trainer
from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage
from ._metric_hook import MetricHook
def _format_number(val):
if isinstance(val, float):
return f'{val:.5f}'
elif torch.is_floating_point(val):
return f'{val.item():.5f}'
return val
class EpochIntervalHook(MetricHook):
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
super().__init__(trainer, priority)
self._interval = interval
def _is_epoch_to_log(self):
return self.trainer.cur_epoch % self._interval == 0
@HOOKS.register_module
class LogMetricByEpochHook(EpochIntervalHook):
"""Specialized Hook to record the metric to log.
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param interval: Recording interval
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
def _get_str(self, mode):
msg = []
for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
msg.append(
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
msg = ', '.join(msg)
return msg
def after_train_epoch(self):
if self._is_epoch_to_log():
msg = self._get_str(mode='train')
if self._is_rank_to_log:
self.logger.info(
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
def after_test_epoch(self):
if self._is_epoch_to_log():
msg = self._get_str(mode='test')
if self._is_rank_to_log:
self.logger.info(
f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
class TensorboardHook(MetricHook):
"""Specialized Hook to record the metric to Tensorboard.
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param log_dir: Directory of log
:type log_dir: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
super().__init__(trainer=trainer, priority=priority)
self._is_rank_to_log = is_no_pp_or_last_stage()
if self._is_rank_to_log:
# create workspace on only one rank
if gpc.is_initialized(ParallelMode.GLOBAL):
rank = gpc.get_global_rank()
else:
rank = 0
log_dir = osp.join(log_dir, f'rank_{rank}')
# create workspace
if not osp.exists(log_dir):
os.makedirs(log_dir)
self.writer = SummaryWriter(
log_dir=log_dir, filename_suffix=f'_rank_{rank}')
def after_train_iter(self, *args):
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_rank_to_log:
self.writer.add_scalar(
f'{metric_name}/train', val, self.trainer.cur_step)
def after_test_iter(self, *args):
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
if metric_calculator.epoch_only:
continue
val = metric_calculator.get_last_step_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/test', val,
self.trainer.cur_step)
def after_test_epoch(self):
for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/test', val,
self.trainer.cur_step)
def after_train_epoch(self):
for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
if metric_calculator.epoch_only:
val = metric_calculator.get_accumulated_value()
if self._is_rank_to_log:
self.writer.add_scalar(f'{metric_name}/train', val,
self.trainer.cur_step)
@HOOKS.register_module
class LogTimingByEpochHook(EpochIntervalHook):
"""Specialized Hook to write timing record to log.
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param interval: Recording interval
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
:param log_eval: Whether writes in evaluation
:type log_eval: bool, optional
"""
def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 1,
log_eval: bool = True
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
set_global_multitimer_status(True)
self._global_timer = get_global_multitimer()
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
def _get_message(self):
msg = []
for timer_name, timer in self._global_timer:
last_elapsed_time = timer.get_elapsed_time()
if timer.has_history:
history_mean = timer.get_history_mean()
history_sum = timer.get_history_sum()
msg.append(
f'{timer_name}: last elapsed time = {last_elapsed_time}, '
f'history sum = {history_sum}, history mean = {history_mean}')
else:
msg.append(
f'{timer_name}: last elapsed time = {last_elapsed_time}')
msg = ', '.join(msg)
return msg
def after_train_epoch(self):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log() and self._is_rank_to_log:
msg = self._get_message()
self.logger.info(
f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
def after_test_epoch(self):
"""Writes log after finishing a testing epoch.
"""
if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
msg = self._get_message()
self.logger.info(
f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}')
@HOOKS.register_module
class LogMemoryByEpochHook(EpochIntervalHook):
"""Specialized Hook to write memory usage record to log.
:param trainer: Trainer attached with current hook
:type trainer: Trainer
:param interval: Recording interval
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int, optional
:param log_eval: Whether writes in evaluation
:type log_eval: bool, optional
"""
def __init__(self,
trainer: Trainer,
interval: int = 1,
priority: int = 1,
log_eval: bool = True
) -> None:
super().__init__(trainer=trainer, interval=interval, priority=priority)
set_global_multitimer_status(True)
self._global_timer = get_global_multitimer()
self._log_eval = log_eval
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
def before_train(self):
"""Resets before training.
"""
if self._is_epoch_to_log() and self._is_rank_to_log:
report_memory_usage('before-train')
def after_train_epoch(self):
"""Writes log after finishing a training epoch.
"""
if self._is_epoch_to_log() and self._is_rank_to_log:
report_memory_usage(
f'After Train - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')
def after_test(self):
"""Reports after testing.
"""
if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval:
report_memory_usage(
f'After Test - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')

View File

@@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.context import ParallelMode
from colossalai.registry import HOOKS
from colossalai.utils import is_no_pp_or_last_stage
from ._base_hook import BaseHook
from .._trainer import Trainer
from ..metric import Loss, Accuracy2D, Accuracy, Accuracy2p5D, Accuracy3D
class MetricHook(BaseHook):
"""Specialized hook classes for :class:`Metric`.
Some help metric collectors initialize, reset and
update their states. Others are used to display and
record the metric.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int):
super().__init__(trainer, priority)
self._is_stage_to_log = is_no_pp_or_last_stage()
self._check_metric_states_initialization()
def _check_metric_states_initialization(self):
if 'metrics' not in self.trainer.states:
self.init_runner_states('metrics', dict(train={}, test={}))
@HOOKS.register_module
class LossHook(MetricHook):
"""Specialized hook class for :class:`Loss`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Loss(epoch_only=False)
# register the metric calculator
self.trainer.states['metrics']['train'][
self.metric.__class__.__name__] = self.metric
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_train_epoch(self):
if self._is_stage_to_log:
self.metric.reset()
def after_train_iter(self, logits, label, loss):
if self._is_stage_to_log:
self.metric.update(loss)
def before_test_epoch(self):
if self._is_stage_to_log:
self.metric.reset()
def after_test_iter(self, logits, label, loss):
if self._is_stage_to_log:
self.metric.update(loss)
@HOOKS.register_module
class Accuracy2DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy2D`.
It acts the same as :class:`AccuracyHook`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int, optional
"""
def __init__(self, trainer: Trainer, priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Accuracy2D(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy2p5DHook(MetricHook):
def __init__(self, trainer: Trainer, priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Accuracy2p5D(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
self.metric.update(logits, label)
@HOOKS.register_module
class Accuracy3DHook(MetricHook):
"""Specialized hook class for :class:`Accuracy3D`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int
"""
def __init__(self,
trainer: Trainer,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Accuracy3D(epoch_only=True,
input_parallel_mode=input_parallel_mode,
weight_parallel_mode=weight_parallel_mode)
# register the metric
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
self.metric.update(logits, label)
@HOOKS.register_module
class AccuracyHook(MetricHook):
"""Specialized hook class for :class:`Accuracy`.
:param trainer: Trainer attached with current hook
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type trainer: Trainer
:type priority: int
"""
def __init__(self, trainer: Trainer, priority: int = 10):
super().__init__(trainer, priority)
if self._is_stage_to_log:
self.metric = Accuracy(epoch_only=True)
# register the metric
self.trainer.states['metrics']['test'][
self.metric.__class__.__name__] = self.metric
def before_test(self):
if self._is_stage_to_log:
self.metric.reset()
def after_test_iter(self, logits, label, *args):
if self._is_stage_to_log:
self.metric.update(logits, label)