This commit is contained in:
Jiarui Fang 2022-06-07 17:21:11 +08:00 committed by GitHub
parent 1b17859328
commit bcab249565
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 12 deletions

View File

@ -4,9 +4,7 @@
import os import os
import os.path as osp import os.path as osp
import torch
from typing import List from typing import List
from decimal import Decimal
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS from colossalai.registry import HOOKS
@ -15,6 +13,7 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
from ._base_hook import BaseHook from ._base_hook import BaseHook
from ._commons_ import _format_number from ._commons_ import _format_number
from colossalai.trainer.hooks._metric_hook import ThroughputMetric
class LogByEpochHook(BaseHook): class LogByEpochHook(BaseHook):
@ -53,12 +52,18 @@ class LogMetricByStepHook(BaseHook):
def after_train_iter(self, trainer, *args): def after_train_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict() trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() if isinstance(metric_calculator, ThroughputMetric):
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
else:
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
def after_test_iter(self, trainer, *args): def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict() trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() if isinstance(metric_calculator, ThroughputMetric):
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
else:
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
@HOOKS.register_module @HOOKS.register_module

View File

@ -52,7 +52,7 @@ class Metric(ABC):
pass pass
@abstractmethod @abstractmethod
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
"""Returns the metric value in the last iteration. """Returns the metric value in the last iteration.
""" """
pass pass
@ -121,10 +121,10 @@ class LossMetric(Metric):
self.accum_loss.div_(self.count) self.accum_loss.div_(self.count)
return self.accum_loss.item() return self.accum_loss.item()
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
"""Returns :attr:`last_step_loss`. """Returns :attr:`last_step_loss`.
""" """
return str(self.last_step_loss.cpu().item()) return self.last_step_loss.cpu().item()
@staticmethod @staticmethod
def is_better(a, b): def is_better(a, b):
@ -149,8 +149,8 @@ class LearningRateMetric(Metric):
def update(self, lr) -> None: def update(self, lr) -> None:
self.lr = lr self.lr = lr
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
return str(self.lr) return self.lr
def get_accumulated_value(self): def get_accumulated_value(self):
return self.lr return self.lr
@ -204,10 +204,10 @@ class AccuracyMetric(Metric):
self.accumulated_sum += self.last_step_sum self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct self.accumulated_correct += self.last_step_correct
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA) self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA) self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
return str(_format_number((self.last_step_correct / self.last_step_sum).cpu().item())) return _format_number((self.last_step_correct / self.last_step_sum).cpu().item())
def get_accumulated_value(self): def get_accumulated_value(self):
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA) self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
@ -350,7 +350,18 @@ class ThroughputMetric(Metric):
self.accumulated_num_samples += self.last_step_num_samples self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time self.accumulated_used_time += self.last_step_used_time
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
if self._use_local:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
return sample_per_sec
def get_last_step_info(self) -> str:
if self._use_local: if self._use_local:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else: else: