mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-19 12:12:46 +00:00
[profiler] add adaptive sampling to memory profiler (#330)
* fix merge conflict modify unit test remove unnessesary log info reformat file * remove unused module * remove unnecessary sync function * change doc string style from Google to Sphinx
This commit is contained in:
parent
1388671699
commit
3213554cc2
@ -8,13 +8,18 @@ from time import sleep, time
|
|||||||
import pickle
|
import pickle
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
||||||
"""
|
"""Get the free memory info of device.
|
||||||
Get the free memory info of device.
|
|
||||||
Notice that for CPU, this function will return 1/N of the total free memory,
|
Notice that for CPU, this function will return 1/N of the total free memory,
|
||||||
where N is the world size.
|
where N is the world size.
|
||||||
|
|
||||||
|
:param device: device id
|
||||||
|
:type device: torch.device
|
||||||
|
:return: current memory usage, sized by MB
|
||||||
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
ret: int = torch.cuda.memory_allocated(device)
|
ret: int = torch.cuda.memory_allocated(device)
|
||||||
# get the peak memory to report correct data, so reset the counter for the next call
|
# get the peak memory to report correct data, so reset the counter for the next call
|
||||||
@ -24,13 +29,16 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
|||||||
|
|
||||||
|
|
||||||
class AsyncMemoryMonitor:
|
class AsyncMemoryMonitor:
|
||||||
|
|
||||||
def __init__(self, power=10):
|
|
||||||
"""
|
"""
|
||||||
An Async Mem Monitor runing during computing.
|
An Async Mem Monitor runing during computing. Sampling GPU memory usage of the current GPU
|
||||||
Sampling GPU memory usage of the current GPU dev
|
|
||||||
at interval of 1/(10**power) sec.
|
at interval of 1/(10**power) sec.
|
||||||
|
|
||||||
|
:param power: the power of time interval, defaults to 10
|
||||||
|
:type power: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, power: int = 10):
|
||||||
|
|
||||||
self.keep_measuring = False
|
self.keep_measuring = False
|
||||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
self.monitor_thread = None
|
self.monitor_thread = None
|
||||||
@ -42,6 +50,7 @@ class AsyncMemoryMonitor:
|
|||||||
return len(self.mem_stats)
|
return len(self.mem_stats)
|
||||||
|
|
||||||
def set_interval(self, power: int):
|
def set_interval(self, power: int):
|
||||||
|
self.clear()
|
||||||
self.interval = 1 / (10**power)
|
self.interval = 1 / (10**power)
|
||||||
|
|
||||||
def is_measuring(self):
|
def is_measuring(self):
|
||||||
@ -89,23 +98,16 @@ class AsyncMemoryMonitor:
|
|||||||
|
|
||||||
@OPHOOKS.register_module
|
@OPHOOKS.register_module
|
||||||
class MemTracerOpHook(BaseOpHook):
|
class MemTracerOpHook(BaseOpHook):
|
||||||
'''
|
"""
|
||||||
Collect GPU memory usage information
|
Collect GPU memory usage information
|
||||||
|
|
||||||
Args:
|
:param warmup: This parameter indicates how many iterations to truncate before profiling, defaults to 50
|
||||||
warmup (int): This parameter indicates how many iterations to truncate
|
:type warmup: int
|
||||||
before profiling, e.g. set to 5 and the data will start from 6-th iteration
|
:param refreshrate: This parameter decides the frequency of write file, defaults to 10
|
||||||
refreshrate (int): This parameter decides the frequency of write file.
|
:type refreshrate: int
|
||||||
datafile(string): the name of the stats data file
|
:param data_prefix: The prefix of the stats data file, defaults to "memstats"
|
||||||
Attributes:
|
:type data_prefix: string
|
||||||
_warmup (int): warmup iterations
|
"""
|
||||||
_refreshrate(int): how many iterations we shall refresh the file
|
|
||||||
_logger (colossalai.logging.logger): output log file
|
|
||||||
_curiter (int): current iteration number
|
|
||||||
_count (int): the number of times the data file was written
|
|
||||||
_data_prefix (string): the prefix of the stats data file
|
|
||||||
_rank (int): the rank of current node
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
|
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -126,6 +128,16 @@ class MemTracerOpHook(BaseOpHook):
|
|||||||
assert isinstance(module, torch.nn.Module)
|
assert isinstance(module, torch.nn.Module)
|
||||||
return module.training
|
return module.training
|
||||||
|
|
||||||
|
def _resample(self):
|
||||||
|
# calculate the average iteration time
|
||||||
|
total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0])
|
||||||
|
avg_it_time = total_time / self.warmup
|
||||||
|
self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
|
||||||
|
# adjust the sampling power
|
||||||
|
power: int = round(-math.log(avg_it_time, 10)) + 1
|
||||||
|
self._logger.debug(f"the power is {power}")
|
||||||
|
self.async_mem_monitor.set_interval(power)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def refreshrate(self) -> int:
|
def refreshrate(self) -> int:
|
||||||
return self._refreshrate
|
return self._refreshrate
|
||||||
@ -146,23 +158,19 @@ class MemTracerOpHook(BaseOpHook):
|
|||||||
if self._isvalid(module):
|
if self._isvalid(module):
|
||||||
self.async_mem_monitor.finish()
|
self.async_mem_monitor.finish()
|
||||||
self.async_mem_monitor.start()
|
self.async_mem_monitor.start()
|
||||||
self._logger.debug(f'FWD PRE {module.__class__.__name__}')
|
|
||||||
|
|
||||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||||
if self._isvalid(module):
|
if self._isvalid(module):
|
||||||
self.async_mem_monitor.finish()
|
self.async_mem_monitor.finish()
|
||||||
self._logger.debug(f'FWD POST {module.__class__.__name__}')
|
|
||||||
|
|
||||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||||
if self._isvalid(module):
|
if self._isvalid(module):
|
||||||
self.async_mem_monitor.finish()
|
self.async_mem_monitor.finish()
|
||||||
self.async_mem_monitor.start()
|
self.async_mem_monitor.start()
|
||||||
self._logger.debug(f'BWD PRE {module.__class__.__name__}')
|
|
||||||
|
|
||||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||||
if self._isvalid(module):
|
if self._isvalid(module):
|
||||||
self.async_mem_monitor.finish()
|
self.async_mem_monitor.finish()
|
||||||
self._logger.debug(f'BWD POST {module.__class__.__name__}')
|
|
||||||
|
|
||||||
def pre_iter(self):
|
def pre_iter(self):
|
||||||
pass
|
pass
|
||||||
@ -170,19 +178,21 @@ class MemTracerOpHook(BaseOpHook):
|
|||||||
def post_iter(self):
|
def post_iter(self):
|
||||||
self.async_mem_monitor.finish()
|
self.async_mem_monitor.finish()
|
||||||
# in the warmup stage
|
# in the warmup stage
|
||||||
if self._curiter < self.warmup:
|
if self.curiter < self.warmup:
|
||||||
# TODO: record time and adaptively change sampling rate
|
|
||||||
pass
|
pass
|
||||||
elif self._curiter == self._warmup:
|
# adjust the sampling rate
|
||||||
self.async_mem_monitor.clear()
|
elif self.curiter == self.warmup:
|
||||||
|
# use adaptive sample rate
|
||||||
|
self._resample()
|
||||||
|
# record data to log file
|
||||||
else:
|
else:
|
||||||
# every `refreshrate` times, refresh the file
|
# every `refreshrate` times, refresh the file
|
||||||
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
|
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
|
||||||
# output file info
|
# output file info
|
||||||
self._logger.info(f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
|
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
|
||||||
self.save_results()
|
self.save_results()
|
||||||
self._count += 1
|
self._count += 1
|
||||||
self._logger.debug(f'data file has been refreshed {self._count} times')
|
self._logger.debug(f"data file has been refreshed {self._count} times")
|
||||||
# finish a iteration
|
# finish a iteration
|
||||||
self._curiter += 1
|
self._curiter += 1
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user