[Gemini] independent runtime tracer (#1974)

Jiarui Fang
2022-11-18 10:53:42 +08:00
committed by GitHub
parent 0da1d00399
commit 0529fcde06
4 changed files with 271 additions and 143 deletions

View File

@@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector # isort:skip
 from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
 from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
 from .static_memstats_collector import StaticMemStatsCollector    # isort:skip
+from .module_tracer_wrapper import MemtracerWrapper    # isort:skip

 __all__ = [
     'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
-    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper'
 ]
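
With the export added to __all__, the tracer wrapper is reachable straight from the package namespace; a minimal import check:

from colossalai.gemini.memory_tracer import MemtracerWrapper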

View File

@@ -1,12 +1,11 @@
+import json
 from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
-import json

 import torch

-from colossalai.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
+from colossalai.utils import colo_device_memory_used, get_current_device


 class MemoryMonitor:
@@ -134,7 +133,13 @@ class SyncCudaMemoryMonitor(MemoryMonitor):
         torch.cuda.synchronize()
         torch.cuda.reset_peak_memory_stats()

-    def finish(self):
+    def finish(self) -> int:
+        """
+        Return the max GPU memory used since the latest `start()`.
+
+        Returns:
+            int: max GPU memory used, in bytes
+        """
         torch.cuda.synchronize()
         self.time_stamps.append(time())
         max_usage = torch.cuda.max_memory_allocated()
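
The typed finish() pairs with start() to bracket a region of interest. A minimal sketch, assuming a CUDA device is available (the tensor shapes are illustrative):

import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()
monitor.start()                  # synchronizes and resets the peak-memory counter
x = torch.randn(1024, 1024, device='cuda')
y = x @ x
peak_bytes = monitor.finish()    # synchronizes and reads torch.cuda.max_memory_allocated()
print(f'peak CUDA memory since start(): {peak_bytes / 1e6:.1f} MB')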

View File

@@ -0,0 +1,36 @@
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook

__all__ = ['MemtracerWrapper']


class _Wrapper:
    """Thin model wrapper that forwards calls and drives the registered op hooks."""

    def __init__(self, model, ophook_list):
        self._ophook_list = ophook_list
        self._model = model

    def __call__(self, *args, **kwargs):
        return self._model(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self._model.forward(*args, **kwargs)

    def backward(self, loss):
        loss.backward()
        # notify every hook that the iteration has finished
        for ophook in self._ophook_list:
            ophook.post_iter()

    def save_results(self, filename):
        for ophook in self._ophook_list:
            ophook.save_results(filename)

    def show_mem_stats(self):
        self._ophook_list[0].show_mem_stats()


def MemtracerWrapper(model):
    ophook_list = [MemTracerOpHook()]
    register_ophooks_recursively(model, ophook_list)
    engine = _Wrapper(model, ophook_list)
    return engine
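
Putting the new file to use: a minimal sketch of wrapping a model and driving one iteration. The toy model, shapes, and filename are illustrative, and a CUDA device is assumed since the hook monitors CUDA memory:

import torch
from colossalai.gemini.memory_tracer import MemtracerWrapper

# hypothetical toy model; any torch.nn.Module works
net = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
model = MemtracerWrapper(net)

data = torch.randn(32, 64, device='cuda')
loss = model(data).sum()
model.backward(loss)                   # runs loss.backward(), then post_iter() on every hook
model.show_mem_stats()                 # prints rebased timestamps and memory readings
model.save_results('memstats.json')    # filename is illustrative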

View File

@@ -0,0 +1,86 @@
import torch

from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.ophooks import BaseOpHook


class MemTracerOpHook(BaseOpHook):

    def __init__(self):
        super().__init__()
        self.mem_monitor = SyncCudaMemoryMonitor()
        self._cur_non_model_data_vol = 0
        self._non_model_data_list = []
        self._cur_model_data_vol = 0

    def _move_module_to_dev(self, module, dev: str) -> int:
        """Move the parameters (and gradients, if any) of ``module`` to ``dev``.

        Args:
            module (torch.nn.Module): a PyTorch module
            dev (str): the target device, e.g. 'cpu' or 'cuda'

        Returns:
            int: the volume of data moved to ``dev``, in bytes
        """
        assert isinstance(dev, str), f"device should be a str, got {type(dev)}"
        comm_volume = 0
        for p in module.parameters():
            if p.data.device.type != dev:
                p.data = p.data.to(dev)
                comm_volume += p.data.numel() * p.data.element_size()
            if p.grad is not None:
                if p.grad.device.type != dev:
                    p.grad = p.grad.to(dev)
                    comm_volume += p.grad.numel() * p.grad.element_size()

        if dev == 'cuda':
            self._cur_model_data_vol = comm_volume

        return comm_volume

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cuda')
            self.mem_monitor.start()
            # print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cpu')
            # print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            self._move_module_to_dev(module, 'cuda')
            self.mem_monitor.start()
            # print(f'BWD PRE {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        # the bwd op generates grads; comm_volume is the grad + data volume moved off cuda
        assert isinstance(module, torch.nn.Module)
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cpu')
            # print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')

    def pre_iter(self):
        pass

    def post_iter(self):
        self.mem_monitor.finish()
        # print(f'post_iter')

    def save_results(self, filename):
        self.mem_monitor.save(filename)

    def show_mem_stats(self):
        # rebase timestamps to the first sample and memory to the minimum observed
        start_timestamp = min(self.mem_monitor.time_stamps)
        self.mem_monitor.time_stamps = [elem - start_timestamp for elem in self.mem_monitor.time_stamps]
        min_mem_used = min(self.mem_monitor.mem_stats)
        self.mem_monitor.mem_stats = [elem - min_mem_used for elem in self.mem_monitor.mem_stats]
        print(self.mem_monitor.time_stamps)
        print(self.mem_monitor.mem_stats)
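
The wrapper above is only a convenience: the hook can also be registered by hand, mirroring what MemtracerWrapper does internally. A minimal sketch (the module and shapes are illustrative; a CUDA device is assumed):

import torch
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook

hook = MemTracerOpHook()
net = torch.nn.Linear(16, 16)    # hypothetical module
register_ophooks_recursively(net, [hook])

out = net(torch.randn(4, 16, device='cuda')).sum()
out.backward()
hook.post_iter()         # close the final monitoring window
hook.show_mem_stats()    # timestamps rebased to the first sample, memory to the minimum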