mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
[refactor] moving memtracer to gemini (#801)
This commit is contained in:
9
colossalai/trainer/hooks/_commons_.py
Normal file
9
colossalai/trainer/hooks/_commons_.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import torch
|
||||
|
||||
|
||||
def _format_number(val, prec=5):
|
||||
if isinstance(val, float):
|
||||
return f'{val:.{prec}g}'
|
||||
elif torch.is_tensor(val) and torch.is_floating_point(val):
|
||||
return f'{val.item():.{prec}g}'
|
||||
return val
|
@@ -14,14 +14,7 @@ from colossalai.logging import DistributedLogger
|
||||
from colossalai.utils import report_memory_usage, is_dp_rank_0, \
|
||||
is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
|
||||
from ._base_hook import BaseHook
|
||||
|
||||
|
||||
def _format_number(val, prec=5):
|
||||
if isinstance(val, float):
|
||||
return f'{val:.{prec}g}'
|
||||
elif torch.is_tensor(val) and torch.is_floating_point(val):
|
||||
return f'{val.item():.{prec}g}'
|
||||
return val
|
||||
from ._commons_ import _format_number
|
||||
|
||||
|
||||
class LogByEpochHook(BaseHook):
|
||||
@@ -35,10 +28,7 @@ class LogByEpochHook(BaseHook):
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger,
|
||||
interval: int = 1,
|
||||
priority: int = 1):
|
||||
def __init__(self, logger, interval: int = 1, priority: int = 1):
|
||||
super().__init__(priority)
|
||||
self.logger = logger
|
||||
self._interval = interval
|
||||
@@ -63,14 +53,12 @@ class LogMetricByStepHook(BaseHook):
|
||||
def after_train_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
|
||||
trainer.states['step_metrics'][metric_name.lower()] = \
|
||||
f'{_format_number(metric_calculator.get_last_step_value())}'
|
||||
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
trainer.states['step_metrics'] = dict()
|
||||
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
|
||||
trainer.states['step_metrics'][metric_name.lower()] = \
|
||||
f'{_format_number(metric_calculator.get_last_step_value())}'
|
||||
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
@@ -85,18 +73,14 @@ class LogMetricByEpochHook(LogByEpochHook):
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger,
|
||||
interval: int = 1,
|
||||
priority: int = 10) -> None:
|
||||
def __init__(self, logger, interval: int = 1, priority: int = 10) -> None:
|
||||
super().__init__(logger, interval, priority)
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
|
||||
|
||||
def _get_str(self, trainer, mode):
|
||||
msg = []
|
||||
for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
|
||||
msg.append(
|
||||
f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
|
||||
msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
|
||||
msg = ' | '.join(msg)
|
||||
return msg
|
||||
|
||||
@@ -130,12 +114,13 @@ class TensorboardHook(BaseHook):
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
log_dir: str,
|
||||
ranks: List = None,
|
||||
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
|
||||
priority: int = 10,
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
log_dir: str,
|
||||
ranks: List = None,
|
||||
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
|
||||
priority: int = 10,
|
||||
) -> None:
|
||||
super().__init__(priority=priority)
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@@ -280,13 +265,14 @@ class LogMemoryByEpochHook(LogByEpochHook):
|
||||
log_eval (bool, optional): Whether writes in evaluation, defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
report_cpu: bool = False, # no reference
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
logger: DistributedLogger,
|
||||
interval: int = 1,
|
||||
priority: int = 10,
|
||||
log_eval: bool = True,
|
||||
report_cpu: bool = False, # no reference
|
||||
) -> None:
|
||||
super().__init__(logger=logger, interval=interval, priority=priority)
|
||||
self._log_eval = log_eval
|
||||
self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from colossalai.registry import HOOKS
|
||||
from torch import Tensor
|
||||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
||||
from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
|
||||
|
||||
|
||||
@HOOKS.register_module
|
||||
|
@@ -13,6 +13,7 @@ from colossalai.registry import HOOKS
|
||||
from colossalai.utils import get_current_device, is_no_pp_or_last_stage
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ._commons_ import _format_number
|
||||
|
||||
|
||||
class Metric(ABC):
|
||||
@@ -51,7 +52,7 @@ class Metric(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_last_step_value(self):
|
||||
def get_last_step_value(self) -> str:
|
||||
"""Returns the metric value in the last iteration.
|
||||
"""
|
||||
pass
|
||||
@@ -120,10 +121,10 @@ class LossMetric(Metric):
|
||||
self.accum_loss.div_(self.count)
|
||||
return self.accum_loss.item()
|
||||
|
||||
def get_last_step_value(self):
|
||||
def get_last_step_value(self) -> str:
|
||||
"""Returns :attr:`last_step_loss`.
|
||||
"""
|
||||
return self.last_step_loss
|
||||
return str(self.last_step_loss)
|
||||
|
||||
@staticmethod
|
||||
def is_better(a, b):
|
||||
@@ -148,8 +149,8 @@ class LearningRateMetric(Metric):
|
||||
def update(self, lr) -> None:
|
||||
self.lr = lr
|
||||
|
||||
def get_last_step_value(self):
|
||||
return self.lr
|
||||
def get_last_step_value(self) -> str:
|
||||
return str(self.lr)
|
||||
|
||||
def get_accumulated_value(self):
|
||||
return self.lr
|
||||
@@ -203,10 +204,10 @@ class AccuracyMetric(Metric):
|
||||
self.accumulated_sum += self.last_step_sum
|
||||
self.accumulated_correct += self.last_step_correct
|
||||
|
||||
def get_last_step_value(self):
|
||||
def get_last_step_value(self) -> str:
|
||||
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
|
||||
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
|
||||
return (self.last_step_correct / self.last_step_sum).item()
|
||||
return str(_format_number((self.last_step_correct / self.last_step_sum).item()))
|
||||
|
||||
def get_accumulated_value(self):
|
||||
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
|
||||
@@ -322,7 +323,8 @@ class ThroughputMetric(Metric):
|
||||
Args:
|
||||
epoch_only (bool): Whether the metric only read for the full epoch.
|
||||
"""
|
||||
def __init__(self, epoch_only: bool, ignored_steps: int = 0):
|
||||
|
||||
def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.ignored_steps = ignored_steps
|
||||
self.cur_steps = 0
|
||||
@@ -330,6 +332,7 @@ class ThroughputMetric(Metric):
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_current_device())
|
||||
self._tflop_per_step = tflop_per_step
|
||||
|
||||
def reset(self) -> None:
|
||||
# self.cur_steps = 0
|
||||
@@ -346,13 +349,18 @@ class ThroughputMetric(Metric):
|
||||
self.accumulated_num_samples += self.last_step_num_samples
|
||||
self.accumulated_used_time += self.last_step_used_time
|
||||
|
||||
def get_last_step_value(self):
|
||||
def get_last_step_value(self) -> str:
|
||||
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
|
||||
return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
|
||||
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
|
||||
if self._tflop_per_step > 0:
|
||||
tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
|
||||
return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
|
||||
else:
|
||||
return f"{sample_per_sec} sample_per_sec"
|
||||
|
||||
def get_accumulated_value(self):
|
||||
def get_accumulated_value(self) -> float:
|
||||
self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
|
||||
gpc.get_world_size(ParallelMode.DATA)
|
||||
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
|
||||
@@ -373,14 +381,18 @@ class ThroughputHook(MetricHook):
|
||||
defaults to 10. If different hooks share same priority, the order of printing would
|
||||
depend on the hooks order in the hook list.
|
||||
"""
|
||||
def __init__(self, ignored_steps: int = 0, priority: int = 10):
|
||||
|
||||
def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0):
|
||||
super().__init__(priority)
|
||||
self.ignored_steps = ignored_steps
|
||||
self._tflop_per_step = tflop_per_step
|
||||
|
||||
def after_hook_is_attached(self, trainer):
|
||||
self._check_metric_states_initialization(trainer)
|
||||
if self._is_stage_to_compute:
|
||||
self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps)
|
||||
self.metric = ThroughputMetric(epoch_only=True,
|
||||
ignored_steps=self.ignored_steps,
|
||||
tflop_per_step=self._tflop_per_step)
|
||||
|
||||
# register the metric
|
||||
trainer.states['metrics']['train']['Throughput'] = self.metric
|
||||
@@ -392,7 +404,8 @@ class ThroughputHook(MetricHook):
|
||||
|
||||
def after_train_iter(self, trainer, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
self.metric.update(trainer.engine.schedule.batch_size,
|
||||
trainer._timer.get_timer('Train-step').get_elapsed_time())
|
||||
|
||||
def before_test(self, trainer):
|
||||
if self._is_stage_to_compute:
|
||||
@@ -400,4 +413,5 @@ class ThroughputHook(MetricHook):
|
||||
|
||||
def after_test_iter(self, trainer, *args):
|
||||
if self._is_stage_to_compute:
|
||||
self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
self.metric.update(trainer.engine.schedule.batch_size,
|
||||
trainer._timer.get_timer('Test-step').get_elapsed_time())
|
||||
|
Reference in New Issue
Block a user