Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-28 11:45:23 +00:00
[Gemini] clean no used MemTraceOp (#1970)
This commit is contained in:
parent 7c7921f71b
commit 7e24b9b9ee
colossalai/gemini/ophooks/__init__.py
@@ -1,4 +1,3 @@
-from .utils import register_ophooks_recursively, BaseOpHook
-from ._memtracer_ophook import MemTracerOpHook
+from .utils import BaseOpHook, register_ophooks_recursively
 
-__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
+__all__ = ["BaseOpHook", "register_ophooks_recursively"]
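
After this change the package's public surface is just BaseOpHook and register_ophooks_recursively. A minimal sketch of that remaining API (the PrintOpHook class and the Linear model are illustrative, not part of the commit, and the register_ophooks_recursively(module, ophook_list) call signature is an assumption; the callback names mirror the deleted hook below):

import torch
from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively

class PrintOpHook(BaseOpHook):
    # Same callback surface the deleted MemTracerOpHook implements below.
    def pre_fwd_exec(self, module, *args):
        print(f"enter forward: {module.__class__.__name__}")

    def post_fwd_exec(self, module, *args):
        print(f"exit forward: {module.__class__.__name__}")

    def pre_bwd_exec(self, module, input, output):
        pass

    def post_bwd_exec(self, module, input):
        pass

    def pre_iter(self):
        pass

    def post_iter(self):
        pass

model = torch.nn.Linear(4, 4)
# Walk the module tree and attach the hook to every submodule.
register_ophooks_recursively(model, [PrintOpHook()])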

colossalai/gemini/ophooks/_memtracer_ophook.py (deleted)
@@ -1,117 +0,0 @@
-import json
-import pickle
-from pathlib import Path
-
-from colossalai.context.parallel_mode import ParallelMode
-import torch
-from colossalai.gemini.ophooks import BaseOpHook
-from colossalai.registry import OPHOOKS
-from colossalai.logging import get_dist_logger
-from colossalai.core import global_context as gpc
-from typing import Union
-import math
-
-
-@OPHOOKS.register_module
-class MemTracerOpHook(BaseOpHook):
-    """
-    Collect GPU memory usage information
-
-    Args:
-        warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50.
-        refreshrate (int): This parameter decides the frequency of write file, defaults to 10.
-        data_prefix (string): The prefix of the stats data file, defaults to "memstats".
-    """
-
-    def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
-        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
-        super().__init__()
-        self.async_mem_monitor = AsyncMemoryMonitor()
-        self._curiter = 0
-        self._logger = get_dist_logger()
-        self._count = 0
-        self._warmup = warmup
-        self._refreshrate = refreshrate
-        self._data_prefix = data_prefix
-        # in distributed environment
-        if gpc.is_initialized(ParallelMode.GLOBAL):
-            self._rank = gpc.get_global_rank()
-        else:
-            self._rank = 0
-
-    def _isvalid(self, module) -> bool:
-        assert isinstance(module, torch.nn.Module)
-        return module.training
-
-    def _resample(self):
-        # calculate the average iteration time
-        total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0])
-        avg_it_time = total_time / self.warmup
-        self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
-        # adjust the sampling power
-        power: int = round(-math.log(avg_it_time, 10)) + 1
-        self._logger.debug(f"the power is {power}")
-        self.async_mem_monitor.set_interval(power)
-
-    @property
-    def refreshrate(self) -> int:
-        return self._refreshrate
-
-    @property
-    def warmup(self) -> int:
-        return self._warmup
-
-    @property
-    def curiter(self) -> int:
-        return self._curiter
-
-    @property
-    def valid_iter(self) -> int:
-        return self.curiter - self.warmup
-
-    def pre_fwd_exec(self, module: torch.nn.Module, *args):
-        if self._isvalid(module):
-            self.async_mem_monitor.finish()
-            self.async_mem_monitor.start()
-
-    def post_fwd_exec(self, module: torch.nn.Module, *args):
-        if self._isvalid(module):
-            self.async_mem_monitor.finish()
-
-    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
-        if self._isvalid(module):
-            self.async_mem_monitor.finish()
-            self.async_mem_monitor.start()
-
-    def post_bwd_exec(self, module: torch.nn.Module, input):
-        if self._isvalid(module):
-            self.async_mem_monitor.finish()
-
-    def pre_iter(self):
-        pass
-
-    def post_iter(self):
-        self.async_mem_monitor.finish()
-        # in the warmup stage
-        if self.curiter < self.warmup:
-            pass
-        # adjust the sampling rate
-        elif self.curiter == self.warmup:
-            # use adaptive sample rate
-            self._resample()
-        # record data to log file
-        else:
-            # every `refreshrate` times, refresh the file
-            if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
-                # output file info
-                self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
-                home_dir = Path.home()
-                with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
-                    pickle.dump(self.async_mem_monitor.state_dict, f)
-                self._count += 1
-                self._logger.debug(f"data file has been refreshed {self._count} times")
-        # finish a iteration
-        self._curiter += 1
-
-    def save_results(self, data_file: Union[str, Path]):
-        with open(data_file, "w") as f:
-            f.write(json.dumps(self.async_mem_monitor.state_dict))
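
The _resample method above adapts the monitor's polling interval to the measured warmup speed. A self-contained sketch of just that arithmetic (the sampling_power helper is illustrative, not a library function):

import math

def sampling_power(total_time: float, warmup: int) -> int:
    # Average wall-clock time of one warmup iteration.
    avg_it_time = total_time / warmup
    # Pick a granularity one order of magnitude finer than the
    # iteration time: 0.1 s/iter -> power 2, i.e. 10^-2 s sampling.
    return round(-math.log(avg_it_time, 10)) + 1

print(sampling_power(5.0, 50))   # 50 iterations in 5 s -> power 2
print(sampling_power(50.0, 50))  # 1 s per iteration -> power 1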

MemProfiler (deleted file)
@@ -1,48 +0,0 @@
-from pathlib import Path
-from typing import Union
-
-from colossalai.engine import Engine
-from torch.utils.tensorboard import SummaryWriter
-
-from colossalai.gemini.ophooks import MemTracerOpHook
-from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
-
-
-class MemProfiler(BaseProfiler):
-    """Wraper of MemOpHook, used to show GPU memory usage through each iteration
-
-    To use this profiler, you need to pass an `engine` instance. And the usage is same like
-    CommProfiler.
-
-    Usage::
-
-        mm_prof = MemProfiler(engine)
-        with ProfilerContext([mm_prof]) as prof:
-            writer = SummaryWriter("mem")
-            engine.train()
-            ...
-
-        prof.to_file("./log")
-        prof.to_tensorboard(writer)
-
-    """
-
-    def __init__(self, engine: Engine, warmup: int = 50, refreshrate: int = 10) -> None:
-        super().__init__(profiler_name="MemoryProfiler", priority=0)
-        self._mem_tracer = MemTracerOpHook(warmup=warmup, refreshrate=refreshrate)
-        self._engine = engine
-
-    def enable(self) -> None:
-        self._engine.add_hook(self._mem_tracer)
-
-    def disable(self) -> None:
-        self._engine.remove_hook(self._mem_tracer)
-
-    def to_tensorboard(self, writer: SummaryWriter) -> None:
-        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
-        for info, i in enumerate(stats):
-            writer.add_scalar("memory_usage/GPU", info, i)
-
-    def to_file(self, data_file: Path) -> None:
-        self._mem_tracer.save_results(data_file)
-
-    def show(self) -> None:
-        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
-        print(stats)
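
One caveat in the deleted to_tensorboard: enumerate yields (index, value) pairs, so `info` was bound to the step index and `i` to the memory reading, and the two arguments to add_scalar were swapped. A corrected sketch, assuming state_dict['mem_stats'] holds a list of GPU memory readings as in the hook above (mem_stats_to_tensorboard is an illustrative helper):

from torch.utils.tensorboard import SummaryWriter

def mem_stats_to_tensorboard(writer: SummaryWriter, stats) -> None:
    # enumerate yields (index, value): step index first, reading second.
    for step, mem in enumerate(stats):
        writer.add_scalar("memory_usage/GPU", mem, step)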