Mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] moving memtracer to gemini (#801)
@@ -1,4 +0,0 @@
|
||||
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
|
||||
from .memstats_collector import MemStatsCollector
|
||||
|
||||
__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
|
@@ -1,142 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import sleep, time
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MemoryMonitor:
    """Base class for all types of memory monitor.
    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
    """

    def __init__(self):
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def finish(self):
        pass

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "w") as f:
            json.dump(self.state_dict(), f)

    def clear(self):
        self.mem_stats.clear()
        self.time_stamps.clear()

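The base class leaves start() and finish() abstract and supplies the bookkeeping (`time_stamps`, `mem_stats`, `state_dict`, `save`, `clear`). As an illustration only, not part of this file, a hypothetical subclass that samples CPU memory with the same `colo_device_memory_used` helper used elsewhere in this diff could look like this:

    from time import time

    import torch

    from colossalai.utils.memory import colo_device_memory_used


    class CpuMemoryMonitor(MemoryMonitor):    # hypothetical subclass, for illustration only
        """Record the CPU memory in use at each finish() call."""

        def start(self):
            pass    # nothing to reset for a point-in-time CPU sample

        def finish(self):
            usage = colo_device_memory_used(torch.device('cpu'))    # bytes
            self.time_stamps.append(time())
            self.mem_stats.append(usage)
            return usage

The inherited helpers then work unchanged, e.g. `monitor.save('cpu_mem.json')` writes `{"time_stamps": [...], "mem_stats": [...]}` as JSON.
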
class AsyncMemoryMonitor(MemoryMonitor):
    """
    An async memory monitor running during computing. It samples the memory usage of the
    current GPU at an interval of `1/(10**power)` sec.

    The idea comes from the Runtime Memory Tracer of PatrickStar
    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    Usage::

        async_mem_monitor = AsyncMemoryMonitor()
        input = torch.randn(2, 20).cuda()
        OP1 = torch.nn.Linear(20, 30).cuda()
        OP2 = torch.nn.Linear(30, 40).cuda()

        async_mem_monitor.start()
        output = OP1(input)
        async_mem_monitor.finish()
        async_mem_monitor.start()
        output = OP2(output)
        async_mem_monitor.finish()
        async_mem_monitor.save('log.pkl')

    Args:
        power (int, optional): the power of the time interval. Defaults to 10.

    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self, power: int = 10):
        super().__init__()
        self.keep_measuring = False

        current_device = get_current_device()

        def _set_cuda_device():
            torch.cuda.set_device(current_device)

        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
        self.monitor_thread = None
        self.interval = 1 / (10**power)

    def set_interval(self, power: int):
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        return self.keep_measuring

    def start(self):
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        if self.keep_measuring is False:
            return 0

        self.keep_measuring = False
        max_usage = self.monitor_thread.result()

        self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        max_usage = 0
        while self.keep_measuring:
            max_usage = max(
                max_usage,
                colo_device_memory_used(get_current_device()),
            )
            sleep(self.interval)
        return max_usage

class SyncCudaMemoryMonitor(MemoryMonitor):
    """
    A synchronized CUDA memory monitor.
    It only records the maximum allocated CUDA memory from the start point to the finish point.
    """

    def __init__(self, power: int = 10):
        super().__init__()

    def start(self):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    def finish(self):
        torch.cuda.synchronize()
        self.time_stamps.append(time())
        max_usage = torch.cuda.max_memory_allocated()
        self.mem_stats.append(max_usage)
        return max_usage
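
In contrast to the async monitor, the synchronized monitor relies on PyTorch's peak-allocation counter rather than a polling thread. A minimal usage sketch, not from the original source; the layer and input are placeholders:

    import torch

    monitor = SyncCudaMemoryMonitor()
    layer = torch.nn.Linear(1024, 1024).cuda()
    x = torch.randn(64, 1024, device='cuda')

    monitor.start()                  # synchronize, then reset the peak-allocation counter
    y = layer(x)
    peak_bytes = monitor.finish()    # synchronize, then read torch.cuda.max_memory_allocated()

    print(f'peak allocated: {peak_bytes / 1024**2:.1f} MiB')
    monitor.save('cuda_peak.json')   # inherited from MemoryMonitor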
@@ -1,143 +0,0 @@
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor
|
||||
import torch
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
|
||||
class MemStatsCollector:
    """
    A memory statistics collector. It works in two phases.

    Phase 1. Collection phase: collect memory usage statistics of CPU and GPU
    during the first iteration of DNN training.
    Phase 2. Runtime phase: use the read-only collected stats
    during the remaining iterations of DNN training.

    It has a sampling counter which is reset after each DNN training iteration.
    """

    def __init__(self) -> None:
        self._mem_monitor = SyncCudaMemoryMonitor()
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []
        self._sampling_time = []

        self._start_flag = False
        self._step_idx = 0
        self._step_total = 0

    def overall_mem_stats(self, device_type: str) -> List[int]:
        if device_type == 'cuda':
            return self._overall_cuda_list
        elif device_type == 'cpu':
            return self._overall_cpu_list
        else:
            raise TypeError

    def model_data_list(self, device_type: str) -> List[int]:
        if device_type == 'cuda':
            return self._model_data_cuda_list
        elif device_type == 'cpu':
            return self._model_data_cpu_list
        else:
            raise TypeError

    def non_model_data_list(self, device_type: str) -> List[int]:
        if device_type == 'cuda':
            return self._non_model_data_cuda_list
        elif device_type == 'cpu':
            return self._non_model_data_cpu_list
        else:
            raise TypeError

    def next_period_non_model_data_usage(self, device_type: str) -> int:
        """Get the maximum non-model-data memory usage of the current sampling period.

        Args:
            device_type (str): device type, can be 'cpu' or 'cuda'.

        Returns:
            int: the maximum non-model-data memory usage of the current sampling period.
        """
        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
        assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
        next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
        self._step_idx = (self._step_idx + 1) % self._step_total
        return next_non_model_data

    @property
    def sampling_time(self):
        return [t - self._sampling_time[0] for t in self._sampling_time]

    def start_collection(self):
        self._start_flag = True
        self._mem_monitor.start()

    def finish_collection(self):
        self.sample_overall_data()
        self._step_total = len(self._sampling_time)
        self._start_flag = False
        self._mem_monitor.finish()

    def sample_model_data(self) -> None:
        """Sample model data statistics.
        """
        if self._start_flag:
            cuda_mem, cpu_mem = GLOBAL_MODEL_DATA_TRACER.both_mem_usage
            self._model_data_cuda_list.append(cuda_mem)
            self._model_data_cpu_list.append(cpu_mem)

    def sample_overall_data(self) -> None:
        """Sample non-model data statistics.
        """
        if self._start_flag:
            # overall data recording is after model data recording
            if len(self._model_data_cuda_list) == 0:
                return

            self._overall_cuda_list.append(self._mem_monitor.finish())
            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))

            assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)

            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
            self._sampling_time.append(time.time())
            self._mem_monitor.start()

    def sample_memstats(self) -> None:
        """
        Sample memory statistics.
        Record the current model data CUDA memory usage as well as system CUDA memory usage.
        Advance the sampling counter.
        """
        if self._start_flag:
            self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
            self._overall_cuda_list.append(self._mem_monitor.finish())
            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])

            self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
            # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
            self._sampling_time.append(time.time())
            self._mem_monitor.start()

    def clear(self) -> None:
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._start_flag = False
        self._step_idx = 0
        self._step_total = 0
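
To make the two phases concrete, here is a schematic sketch of driving the collector by hand; it is not from the original source, the loop stands in for the real training work, and sample_model_data only returns useful numbers once a model has been registered with GLOBAL_MODEL_DATA_TRACER:

    collector = MemStatsCollector()

    # Phase 1: collection, during the first training iteration
    collector.start_collection()
    for _ in range(4):                       # four sampling periods, purely illustrative
        # ... placeholder for the real forward/backward work of this period ...
        collector.sample_model_data()        # model data, via GLOBAL_MODEL_DATA_TRACER
        collector.sample_overall_data()      # peak CUDA usage and CPU usage of this period
    collector.finish_collection()

    # Phase 2: runtime, the remaining iterations read the collected stats
    non_model_cuda = collector.next_period_non_model_data_usage('cuda')   # bytes; advances the counter
    overall_cuda = collector.overall_mem_stats('cuda')                    # per-period peaks, in bytes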
@@ -1,109 +0,0 @@
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
import torch
|
||||
from typing import Tuple, Optional
|
||||
from colossalai.logging import DistributedLogger
|
||||
|
||||
|
||||
def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
    """Trace the optimizer memory usage.

    Args:
        optim (ShardedOptimV2): an instance of ShardedOptimV2

    Returns:
        Tuple[int, int]: CUDA and CPU memory usage in bytes
    """
    if optim is None:
        return 0, 0
    assert hasattr(optim, 'get_memory_usage'), f"{type(optim)} has no attr get_memory_usage()"
    return optim.get_memory_usage()

def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Trace the model memory usage.

    Args:
        model (torch.nn.Module): a torch model

    Returns:
        Tuple[int, int]: CUDA memory usage in bytes, CPU memory usage in bytes
    """
    if model is None:
        return 0, 0

    def _get_tensor_mem_use(t: Optional[torch.Tensor]):
        if t is None:
            return 0, 0
        assert isinstance(t, torch.Tensor)
        _cpu_mem_usage, _cuda_mem_usage = 0, 0
        if t.device.type == 'cpu':
            _cpu_mem_usage += t.numel() * t.element_size()
        elif t.device.type == 'cuda':
            _cuda_mem_usage += t.numel() * t.element_size()
        return _cuda_mem_usage, _cpu_mem_usage

    cuda_mem_usage = 0
    cpu_mem_usage = 0
    for param in model.parameters():
        if hasattr(param, 'colo_attr'):
            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu
        else:
            t_cuda, t_cpu = _get_tensor_mem_use(param.data)
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu
            t_cuda, t_cpu = _get_tensor_mem_use(param.grad)
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu

    return cuda_mem_usage, cpu_mem_usage

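For a plain model whose parameters carry no `colo_attr`, the helper simply counts numel() * element_size() for each param.data (and param.grad, if it exists). A quick illustrative sketch, not part of the original file:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.ReLU()).cuda()
    cuda_bytes, cpu_bytes = colo_model_mem_usage(model)
    # With float32 parameters this reports roughly (512*512 + 512) * 4 bytes on CUDA
    # and 0 bytes on CPU, since gradients do not exist before the first backward pass.
    print(cuda_bytes, cpu_bytes)
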
class ModelDataTracer(metaclass=SingletonMeta):
    """
    A tracer singleton to trace model data usage during runtime.
    You have to register a model on the singleton first.
    """

    def __init__(self) -> None:
        self._logger = DistributedLogger("ModelDataTracer")
        self._model = None
        self._opitimizer = None

    def _get_mem_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the registered model and optimizer.

        Returns:
            Tuple[int, int]: CUDA and CPU memory usage in bytes
        """
        cuda_use_opt, cpu_use_opt = colo_model_optimizer_usage(self._opitimizer)
        cuda_use_model, cpu_use_model = colo_model_mem_usage(self._model)
        return cuda_use_opt + cuda_use_model, cpu_use_opt + cpu_use_model

    def register_model(self, model) -> None:
        if self._model is not None:
            self._logger.warning("ModelDataTracer has already registered a model")
        self._model = model

    def register_optimizer(self, optimizer) -> None:
        if self._opitimizer is not None:
            self._logger.warning("ModelDataTracer has already registered an optimizer")
        self._opitimizer = optimizer

    @property
    def cpu_usage(self):
        _, cpu_usage = self._get_mem_usage()
        return cpu_usage

    @property
    def cuda_usage(self):
        cuda_usage, _ = self._get_mem_usage()
        return cuda_usage

    @property
    def both_mem_usage(self):
        return self._get_mem_usage()


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
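
The singleton is what MemStatsCollector.sample_model_data reads through `both_mem_usage`. A schematic sketch of how it is fed, not from the original source; the registration is presumably performed elsewhere in ColossalAI (e.g. by the sharded model/optimizer setup), and `sharded_optimizer` below is a placeholder for an optimizer exposing get_memory_usage():

    import torch

    model = torch.nn.Linear(256, 256).cuda()                 # placeholder model
    GLOBAL_MODEL_DATA_TRACER.register_model(model)
    # GLOBAL_MODEL_DATA_TRACER.register_optimizer(sharded_optimizer)   # needs get_memory_usage()

    cuda_bytes = GLOBAL_MODEL_DATA_TRACER.cuda_usage                  # model (+ optimizer) bytes on GPU
    cpu_bytes = GLOBAL_MODEL_DATA_TRACER.cpu_usage                    # model (+ optimizer) bytes on CPU
    cuda_bytes, cpu_bytes = GLOBAL_MODEL_DATA_TRACER.both_mem_usage   # both at once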