#!/usr/bin/env python # -*- encoding: utf-8 -*- import os import os.path as osp import torch from torch.utils.tensorboard import SummaryWriter from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.registry import HOOKS from colossalai.trainer._trainer import Trainer from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \ is_tp_rank_0, is_no_pp_or_last_stage from ._base_hook import BaseHook def _format_number(val): if isinstance(val, float): return f'{val:.5f}' elif torch.is_floating_point(val): return f'{val.item():.5f}' return val class EpochIntervalHook(BaseHook): def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1): super().__init__(trainer, priority) self._interval = interval def _is_epoch_to_log(self): return self.trainer.cur_epoch % self._interval == 0 @HOOKS.register_module class LogMetricByEpochHook(EpochIntervalHook): """Specialized Hook to record the metric to log. :param trainer: Trainer attached with current hook :type trainer: Trainer :param interval: Recording interval :type interval: int, optional :param priority: Priority in the printing, hooks with small priority will be printed in front :type priority: int, optional """ def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage() def _get_str(self, mode): msg = [] for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items(): msg.append( f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}') msg = ', '.join(msg) return msg def after_train_epoch(self): if self._is_epoch_to_log(): msg = self._get_str(mode='train') if self._is_rank_to_log: self.logger.info( f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}') def after_test_epoch(self): if self._is_epoch_to_log(): msg = self._get_str(mode='test') if self._is_rank_to_log: self.logger.info( f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}') @HOOKS.register_module class TensorboardHook(BaseHook): """Specialized Hook to record the metric to Tensorboard. :param trainer: Trainer attached with current hook :type trainer: Trainer :param log_dir: Directory of log :type log_dir: str, optional :param priority: Priority in the printing, hooks with small priority will be printed in front :type priority: int, optional """ def __init__(self, trainer: Trainer, log_dir: str, dp_rank_0_only: bool = True, tp_rank_0_only: bool = True, priority: int = 10, ) -> None: super().__init__(trainer=trainer, priority=priority) # create log dir if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: os.makedirs(log_dir, exist_ok=True) # determine the ranks to generate tensorboard logs self._is_valid_rank_to_log = is_no_pp_or_last_stage() if dp_rank_0_only: self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0() if tp_rank_0_only: self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0() if self._is_valid_rank_to_log: # create workspace on only one rank if gpc.is_initialized(ParallelMode.GLOBAL): rank = gpc.get_global_rank() else: rank = 0 # create workspace log_dir = osp.join(log_dir, f'rank_{rank}') os.makedirs(log_dir, exist_ok=True) self.writer = SummaryWriter( log_dir=log_dir, filename_suffix=f'_rank_{rank}') def _log_by_iter(self, mode: str): for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items(): if metric_calculator.epoch_only: continue val = metric_calculator.get_last_step_value() if self._is_valid_rank_to_log: self.writer.add_scalar(f'{metric_name}/{mode}', val, self.trainer.cur_step) def _log_by_epoch(self, mode: str): for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items(): if metric_calculator.epoch_only: val = metric_calculator.get_accumulated_value() if self._is_valid_rank_to_log: self.writer.add_scalar(f'{metric_name}/{mode}', val, self.trainer.cur_step) def after_test_iter(self, *args): self._log_by_iter(mode='test') def after_test_epoch(self): self._log_by_epoch(mode='test') def after_train_iter(self, *args): self._log_by_iter(mode='train') def after_train_epoch(self): self._log_by_epoch(mode='train') @HOOKS.register_module class LogTimingByEpochHook(EpochIntervalHook): """Specialized Hook to write timing record to log. :param trainer: Trainer attached with current hook :type trainer: Trainer :param interval: Recording interval :type interval: int, optional :param priority: Priority in the printing, hooks with small priority will be printed in front :type priority: int, optional :param log_eval: Whether writes in evaluation :type log_eval: bool, optional """ def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10, log_eval: bool = True ) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) set_global_multitimer_status(True) self._global_timer = get_global_multitimer() self._log_eval = log_eval self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() def _get_message(self): msg = [] for timer_name, timer in self._global_timer: last_elapsed_time = timer.get_elapsed_time() if timer.has_history: history_mean = timer.get_history_mean() history_sum = timer.get_history_sum() msg.append( f'{timer_name}: last elapsed time = {last_elapsed_time}, ' f'history sum = {history_sum}, history mean = {history_mean}') else: msg.append( f'{timer_name}: last elapsed time = {last_elapsed_time}') msg = ', '.join(msg) return msg def after_train_epoch(self): """Writes log after finishing a training epoch. """ if self._is_epoch_to_log() and self._is_rank_to_log: msg = self._get_message() self.logger.info( f'Training - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}') def after_test_epoch(self): """Writes log after finishing a testing epoch. """ if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval: msg = self._get_message() self.logger.info( f'Testing - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}: {msg}') @HOOKS.register_module class LogMemoryByEpochHook(EpochIntervalHook): """Specialized Hook to write memory usage record to log. :param trainer: Trainer attached with current hook :type trainer: Trainer :param interval: Recording interval :type interval: int, optional :param priority: Priority in the printing, hooks with small priority will be printed in front :type priority: int, optional :param log_eval: Whether writes in evaluation :type log_eval: bool, optional """ def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10, log_eval: bool = True ) -> None: super().__init__(trainer=trainer, interval=interval, priority=priority) set_global_multitimer_status(True) self._global_timer = get_global_multitimer() self._log_eval = log_eval self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() def before_train(self): """Resets before training. """ if self._is_epoch_to_log() and self._is_rank_to_log: report_memory_usage('before-train') def after_train_epoch(self): """Writes log after finishing a training epoch. """ if self._is_epoch_to_log() and self._is_rank_to_log: report_memory_usage( f'After Train - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}') def after_test(self): """Reports after testing. """ if self._is_epoch_to_log() and self._is_rank_to_log and self._log_eval: report_memory_usage( f'After Test - Epoch {self.trainer.cur_epoch} - {self.__class__.__name__}')