From 3213554cc2eb18d50583d62df1dda2193b96ba7c Mon Sep 17 00:00:00 2001 From: Jie Zhu Date: Wed, 9 Mar 2022 11:07:10 +0800 Subject: [PATCH] [profiler] add adaptive sampling to memory profiler (#330) * fix merge conflict modify unit test remove unnessesary log info reformat file * remove unused module * remove unnecessary sync function * change doc string style from Google to Sphinx --- .../engine/ophooks/_memtracer_ophook.py | 74 +++++++++++-------- 1 file changed, 42 insertions(+), 32 deletions(-) diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/engine/ophooks/_memtracer_ophook.py index 663d3d8a9..3ba8e536d 100644 --- a/colossalai/engine/ophooks/_memtracer_ophook.py +++ b/colossalai/engine/ophooks/_memtracer_ophook.py @@ -8,13 +8,18 @@ from time import sleep, time import pickle from typing import Optional from colossalai.core import global_context as gpc +import math def get_cuda_memory_used(device: Optional[torch.device]) -> int: - """ - Get the free memory info of device. + """Get the free memory info of device. Notice that for CPU, this function will return 1/N of the total free memory, where N is the world size. + + :param device: device id + :type device: torch.device + :return: current memory usage, sized by MB + :rtype: int """ ret: int = torch.cuda.memory_allocated(device) # get the peak memory to report correct data, so reset the counter for the next call @@ -24,13 +29,16 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int: class AsyncMemoryMonitor: - - def __init__(self, power=10): - """ - An Async Mem Monitor runing during computing. - Sampling GPU memory usage of the current GPU dev + """ + An Async Mem Monitor runing during computing. Sampling GPU memory usage of the current GPU at interval of 1/(10**power) sec. + + :param power: the power of time interval, defaults to 10 + :type power: int """ + + def __init__(self, power: int = 10): + self.keep_measuring = False self.executor = ThreadPoolExecutor(max_workers=1) self.monitor_thread = None @@ -42,6 +50,7 @@ class AsyncMemoryMonitor: return len(self.mem_stats) def set_interval(self, power: int): + self.clear() self.interval = 1 / (10**power) def is_measuring(self): @@ -89,23 +98,16 @@ class AsyncMemoryMonitor: @OPHOOKS.register_module class MemTracerOpHook(BaseOpHook): - ''' + """ Collect GPU memory usage information - Args: - warmup (int): This parameter indicates how many iterations to truncate - before profiling, e.g. set to 5 and the data will start from 6-th iteration - refreshrate (int): This parameter decides the frequency of write file. - datafile(string): the name of the stats data file - Attributes: - _warmup (int): warmup iterations - _refreshrate(int): how many iterations we shall refresh the file - _logger (colossalai.logging.logger): output log file - _curiter (int): current iteration number - _count (int): the number of times the data file was written - _data_prefix (string): the prefix of the stats data file - _rank (int): the rank of current node - ''' + :param warmup: This parameter indicates how many iterations to truncate before profiling, defaults to 50 + :type warmup: int + :param refreshrate: This parameter decides the frequency of write file, defaults to 10 + :type refreshrate: int + :param data_prefix: The prefix of the stats data file, defaults to "memstats" + :type data_prefix: string + """ def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"): super().__init__() @@ -126,6 +128,16 @@ class MemTracerOpHook(BaseOpHook): assert isinstance(module, torch.nn.Module) return module.training + def _resample(self): + # calculate the average iteration time + total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0]) + avg_it_time = total_time / self.warmup + self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s") + # adjust the sampling power + power: int = round(-math.log(avg_it_time, 10)) + 1 + self._logger.debug(f"the power is {power}") + self.async_mem_monitor.set_interval(power) + @property def refreshrate(self) -> int: return self._refreshrate @@ -146,23 +158,19 @@ class MemTracerOpHook(BaseOpHook): if self._isvalid(module): self.async_mem_monitor.finish() self.async_mem_monitor.start() - self._logger.debug(f'FWD PRE {module.__class__.__name__}') def post_fwd_exec(self, module: torch.nn.Module, *args): if self._isvalid(module): self.async_mem_monitor.finish() - self._logger.debug(f'FWD POST {module.__class__.__name__}') def pre_bwd_exec(self, module: torch.nn.Module, input, output): if self._isvalid(module): self.async_mem_monitor.finish() self.async_mem_monitor.start() - self._logger.debug(f'BWD PRE {module.__class__.__name__}') def post_bwd_exec(self, module: torch.nn.Module, input): if self._isvalid(module): self.async_mem_monitor.finish() - self._logger.debug(f'BWD POST {module.__class__.__name__}') def pre_iter(self): pass @@ -170,19 +178,21 @@ class MemTracerOpHook(BaseOpHook): def post_iter(self): self.async_mem_monitor.finish() # in the warmup stage - if self._curiter < self.warmup: - # TODO: record time and adaptively change sampling rate + if self.curiter < self.warmup: pass - elif self._curiter == self._warmup: - self.async_mem_monitor.clear() + # adjust the sampling rate + elif self.curiter == self.warmup: + # use adaptive sample rate + self._resample() + # record data to log file else: # every `refreshrate` times, refresh the file if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0: # output file info - self._logger.info(f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl') + self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl") self.save_results() self._count += 1 - self._logger.debug(f'data file has been refreshed {self._count} times') + self._logger.debug(f"data file has been refreshed {self._count} times") # finish a iteration self._curiter += 1