mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
[profiler] add MemProfiler (#356)
* add memory trainer hook * fix bug * add memory trainer hook * fix import bug * fix import bug * add trainer hook * fix #370 git log bug * modify `to_tensorboard` function to support better output * remove useless output * change the name of `MemProfiler` * complete memory profiler * replace error with warning * finish trainer hook * modify interface of MemProfiler * modify `__init__.py` in profiler * remove unnecessary pass statement * add usage to doc string * add usage to trainer hook * new location to store temp data file
This commit is contained in:
44
colossalai/trainer/hooks/_mem_tracer_hook.py
Normal file
44
colossalai/trainer/hooks/_mem_tracer_hook.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from cgitb import Hook
|
||||
from colossalai.registry import HOOKS
|
||||
from torch import Tensor
|
||||
from colossalai.trainer.hooks import BaseHook
|
||||
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
||||
from ._metric_hook import LearningRateMetric, MetricHook
|
||||
|
||||
@HOOKS.register_module
class MemTraceHook(BaseHook):
    """Record memory statistics around each iteration and publish them to the trainer.

    The hook owns an ``AsyncMemoryMonitor``. It calls ``start()`` on the monitor
    before every train/test iteration and ``finish()`` after it, then copies the
    monitor's ``state_dict`` into ``trainer.states['metrics'][mode]`` for both
    ``mode = 'train'`` and ``mode = 'test'``.

    Usage: attach it like any other trainer hook and read the collected data
    from ``trainer.states['metrics'][mode]``.

    Note:
        The whole ``'train'`` and ``'test'`` entries of
        ``trainer.states['metrics']`` are assigned by this hook, so any value
        previously stored under those two keys is replaced.

    Args:
        priority (int): Forwarded unchanged to ``BaseHook.__init__``.
            Defaults to 0.
    """

    def __init__(
        self,
        priority: int = 0,
    ) -> None:
        super().__init__(priority=priority)
        self._memory_monitor = AsyncMemoryMonitor()

    def _publish_stats(self, trainer) -> None:
        """Store the monitor's current ``state_dict`` under both metric modes."""
        # Both modes receive the same object, exactly as the hook has always done;
        # the monitor is shared between training and testing.
        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
        trainer.states['metrics']['test'] = self._memory_monitor.state_dict

    def after_hook_is_attached(self, trainer):
        """Seed ``trainer.states['metrics']`` so the keys exist before any iteration."""
        # Initialize the data
        self._publish_stats(trainer)

    def before_train_iter(self, trainer):
        """Start the memory monitor, then defer to the base hook."""
        self._memory_monitor.start()
        return super().before_train_iter(trainer)

    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Finish the monitor, publish its stats, then defer to the base hook."""
        self._memory_monitor.finish()
        self._publish_stats(trainer)
        return super().after_train_iter(trainer, output, label, loss)

    def before_test_iter(self, trainer):
        """Start the memory monitor, then defer to the base hook."""
        self._memory_monitor.start()
        # Delegate to the matching per-iteration hook. The previous code called
        # ``before_test`` here, which is a different lifecycle callback.
        return super().before_test_iter(trainer)

    def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
        """Finish the monitor, publish its stats, then defer to the base hook."""
        self._memory_monitor.finish()
        self._publish_stats(trainer)
        return super().after_test_iter(trainer, output, label, loss)
|
Reference in New Issue
Block a user