[Gemini] independent runtime tracer (#1974)

This commit is contained in:
Jiarui Fang 2022-11-18 10:53:42 +08:00 committed by GitHub
parent 0da1d00399
commit 0529fcde06
4 changed files with 271 additions and 143 deletions

colossalai/gemini/memory_tracer/__init__.py

@@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector    # isort:skip
 from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
 from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
 from .static_memstats_collector import StaticMemStatsCollector    # isort:skip
+from .module_tracer_wrapper import MemtracerWrapper    # isort:skip

 __all__ = [
     'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
-    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper'
 ]
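With this re-export in place, user code can pull the tracer straight from the package root. A minimal sketch (the wrapped model is illustrative; the wrapper itself is defined in module_tracer_wrapper.py below):

from colossalai.gemini.memory_tracer import MemtracerWrapper

model = MemtracerWrapper(model)    # wraps the model and registers the memory-tracing ophooks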

colossalai/gemini/memory_tracer/memory_monitor.py

@@ -1,142 +1,147 @@
+import json
 from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
-import json
+
 import torch
-from colossalai.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
+
+from colossalai.utils import colo_device_memory_used, get_current_device


 class MemoryMonitor:
     """Base class for all types of memory monitor.
     All monitors should have a list called `time_stamps` and a list called `mem_stats`.
     """

     def __init__(self):
         self.time_stamps = []
         self.mem_stats = []

     def __len__(self):
         return len(self.mem_stats)

     @abstractmethod
     def start(self):
         pass

     @abstractmethod
     def finish(self):
         pass

     def state_dict(self):
         return {
             "time_stamps": self.time_stamps,
             "mem_stats": self.mem_stats,
         }

     def save(self, filename):
         with open(filename, "w") as f:
             json.dump(self.state_dict(), f)

     def clear(self):
         self.mem_stats.clear()
         self.time_stamps.clear()


 class AsyncMemoryMonitor(MemoryMonitor):
     """
     An async memory monitor that runs during computation, sampling the memory usage
     of the current GPU at an interval of `1/(10**power)` sec.

     The idea comes from the Runtime Memory Tracer of PatrickStar
     `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

     Usage::

         async_mem_monitor = AsyncMemoryMonitor()
         input = torch.randn(2, 20).cuda()
         OP1 = torch.nn.Linear(20, 30).cuda()
         OP2 = torch.nn.Linear(30, 40).cuda()

         async_mem_monitor.start()
         output = OP1(input)
         async_mem_monitor.finish()
         async_mem_monitor.start()
         output = OP2(output)
         async_mem_monitor.finish()
         async_mem_monitor.save('log.pkl')

     Args:
         power (int, optional): the power of the sampling time interval, i.e. the monitor
             samples every `1/(10**power)` sec. Defaults to 10.

     .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
         https://arxiv.org/abs/2108.05818
     """

     def __init__(self, power: int = 10):
         super().__init__()
         self.keep_measuring = False

         current_device = get_current_device()

         def _set_cuda_device():
             torch.cuda.set_device(current_device)

         self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
         self.monitor_thread = None
         self.interval = 1 / (10**power)

     def set_interval(self, power: int):
         self.clear()
         self.interval = 1 / (10**power)

     def is_measuring(self):
         return self.keep_measuring

     def start(self):
         self.keep_measuring = True
         self.monitor_thread = self.executor.submit(self._measure_usage)

     def finish(self):
         if self.keep_measuring is False:
             return 0

         self.keep_measuring = False
         max_usage = self.monitor_thread.result()

         self.monitor_thread = None
         self.time_stamps.append(time())
         self.mem_stats.append(max_usage)
         return max_usage

     def _measure_usage(self):
         max_usage = 0
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
                 colo_device_memory_used(get_current_device()),
             )
             sleep(self.interval)
         return max_usage


 class SyncCudaMemoryMonitor(MemoryMonitor):
     """
     A synchronized CUDA memory monitor.
     It only records the maximum allocated CUDA memory from the start point to the finish point.
     """

     def __init__(self, power: int = 10):
         super().__init__()

     def start(self):
         torch.cuda.synchronize()
         torch.cuda.reset_peak_memory_stats()

-    def finish(self):
+    def finish(self) -> int:
+        """
+        Return the max GPU memory used since the latest `start()`.
+
+        Returns:
+            int: max GPU memory in bytes
+        """
         torch.cuda.synchronize()
         self.time_stamps.append(time())
         max_usage = torch.cuda.max_memory_allocated()
         self.mem_stats.append(max_usage)
         return max_usage
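SyncCudaMemoryMonitor is the measurement primitive the new op hook builds on: start() resets the CUDA peak-allocation counter, finish() synchronizes and reads it back. A minimal usage sketch (the layer, sizes, and output filename are illustrative):

import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()
layer = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device='cuda')

monitor.start()                       # synchronize, then reset the peak-memory counter
y = layer(x)
peak_bytes = monitor.finish()         # synchronize, then read torch.cuda.max_memory_allocated()
print(f'peak CUDA memory: {peak_bytes / 1e6:.1f} MB')
monitor.save('sync_mem_stats.json')   # the base class dumps state_dict() as JSON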

colossalai/gemini/memory_tracer/module_tracer_wrapper.py

@@ -0,0 +1,36 @@
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook

__all__ = ['MemtracerWrapper']


class _Wrapper:

    def __init__(self, model, ophook_list):
        self._ophook_list = ophook_list
        self._model = model

    def __call__(self, *args, **kwargs):
        return self._model(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self._model.forward(*args, **kwargs)

    def backward(self, loss):
        loss.backward()
        for ophook in self._ophook_list:
            ophook.post_iter()

    def save_results(self, filename):
        for ophook in self._ophook_list:
            ophook.save_results(filename)

    def show_mem_stats(self):
        self._ophook_list[0].show_mem_stats()


def MemtracerWrapper(model):
    ophook_list = [MemTracerOpHook()]
    register_ophooks_recursively(model, ophook_list)
    engine = _Wrapper(model, ophook_list)
    return engine
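Taken together, MemtracerWrapper is meant to be a drop-in for a plain training loop. Note that backward must be called on the wrapper, not on the loss directly, so that each hook's post_iter() runs. A hedged end-to-end sketch (model, shapes, and filename are illustrative; a CUDA device is assumed, since the hooks shuttle parameters between 'cpu' and 'cuda'):

import torch
from colossalai.gemini.memory_tracer import MemtracerWrapper

model = torch.nn.Sequential(torch.nn.Linear(20, 30), torch.nn.Linear(30, 1))
model = MemtracerWrapper(model)       # registers MemTracerOpHook on every submodule

for _ in range(4):
    x = torch.randn(8, 20, device='cuda')
    loss = model(x).sum()             # forwards through _Wrapper.__call__
    model.backward(loss)              # loss.backward() plus post_iter() on each hook

model.save_results('mem_trace.json')
model.show_mem_stats()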

colossalai/gemini/ophooks/mem_trace_hook.py

@@ -0,0 +1,86 @@
import torch

from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.ophooks import BaseOpHook


class MemTracerOpHook(BaseOpHook):

    def __init__(self):
        super().__init__()
        self.mem_monitor = SyncCudaMemoryMonitor()
        self._cur_non_model_data_vol = 0
        self._non_model_data_list = []
        self._cur_model_data_vol = 0

    def _move_module_to_dev(self, module, dev: str) -> int:
        """
        Move the parameters (and their grads) of a module to the target device.

        Args:
            module (torch.nn.Module): a PyTorch module
            dev (str): the target device, e.g. 'cuda' or 'cpu'

        Returns:
            int: the data volume (in bytes) of this module moved to the device
        """
        assert isinstance(dev, str), "device should be a str, not torch.device"
        comm_volume = 0
        for p in module.parameters():
            if p.data.device.type != dev:
                p.data = p.data.to(dev)
                comm_volume += p.data.numel() * p.data.element_size()
            if p.grad is not None:
                if p.grad.device.type != dev:
                    p.grad = p.grad.to(dev)
                    comm_volume += p.grad.numel() * p.grad.element_size()
        if dev == 'cuda':
            self._cur_model_data_vol = comm_volume
        return comm_volume

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cuda')
            self.mem_monitor.start()
            # print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cpu')
            # print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            self._move_module_to_dev(module, 'cuda')
            self.mem_monitor.start()
            # print(f'BWD PRE {module.__class__.__name__}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        # The bwd op generates grads; comm_volume is the grad + data volume moved off cuda.
        assert isinstance(module, torch.nn.Module)
        if module.training:
            cuda_volume = self.mem_monitor.finish()
            comm_volume = self._move_module_to_dev(module, 'cpu')
            # print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')

    def pre_iter(self):
        pass

    def post_iter(self):
        self.mem_monitor.finish()
        # print(f'post_iter')

    def save_results(self, filename):
        self.mem_monitor.save(filename)

    def show_mem_stats(self):
        # Normalize time stamps to start at 0 and memory stats relative to the minimum usage.
        start_timestamp = min(self.mem_monitor.time_stamps)
        self.mem_monitor.time_stamps = [elem - start_timestamp for elem in self.mem_monitor.time_stamps]
        min_mem_used = min(self.mem_monitor.mem_stats)
        self.mem_monitor.mem_stats = [elem - min_mem_used for elem in self.mem_monitor.mem_stats]
        print(self.mem_monitor.time_stamps)
        print(self.mem_monitor.mem_stats)
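The commented-out prints show the accounting the hook is after: cuda_volume is the peak CUDA usage the monitor observed since the previous op, comm_volume is the model data (parameters and grads) the hook itself just moved, and their difference approximates non-model data (activations, buffers, temporaries). Note that _non_model_data_list is initialized but not yet appended to in this commit. The arithmetic, with hypothetical numbers:

cuda_volume = 512 * 1024**2    # hypothetical: peak CUDA bytes seen between two ops
comm_volume = 120 * 1024**2    # hypothetical: bytes of params/grads moved by the hook

non_model_data = cuda_volume - comm_volume    # activations, buffers, temporaries
print(f'non-model data: {non_model_data / 1e6:.0f} MB')    # same MB convention as the hooks' prints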