Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] refactor memstats collector (#706)
* refactor memstats collector
* fix disposable
* polish code
@@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
                      is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param)
+                     sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory_utils.memory_monitor import report_memory_usage
@@ -19,5 +19,5 @@ __all__ = [
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
     'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists'
+    'ensure_path_exists', 'disposable'
 ]
@@ -4,8 +4,8 @@ import os
 import random
 import socket
 from pathlib import Path
-from typing import List, Union
-
+from typing import Callable, List, Union
+import functools
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
 
 
 class model_branch_context(object):
+
     def __enter__(self):
         self.env_status = env.save()
 
@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
             colossal_C.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
-            False # no per-parameter norm
+            False  # no per-parameter norm
         )
     return norm
 
@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
         yield
     finally:
         gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
+def disposable(func: Callable) -> Callable:
+    executed = False
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal executed
+        if not executed:
+            executed = True
+            return func(*args, **kwargs)
+
+    return wrapper
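The `disposable` helper added above is a run-at-most-once decorator: the first call executes the wrapped function, and every later call is skipped (the wrapper simply returns None). A minimal usage sketch follows; `warn_once` is a hypothetical example function, and the import relies on the re-export that the first hunk of this commit adds to `colossalai.utils`.

from colossalai.utils import disposable

@disposable
def warn_once(msg: str) -> None:
    # hypothetical one-shot helper: the body runs only on the first invocation
    print(msg)

warn_once('gradient overflow, skipping step')  # prints the message
warn_once('gradient overflow, skipping step')  # no-op: the closure flag `executed` is already True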
@@ -1,36 +1,11 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_device_memory_used
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List
 
 
-class SamplingCounter:
-
-    def __init__(self) -> None:
-        self._samplint_cnt = 0
-        self._max_sampling_cnt = None
-
-    def advance(self):
-        self._samplint_cnt += 1
-
-    def next(self):
-        assert self._max_sampling_cnt is not None
-        return (self._samplint_cnt + 1) % self._max_sampling_cnt
-
-    def current(self):
-        return self._samplint_cnt
-
-    def max(self):
-        return self._max_sampling_cnt
-
-    def reset(self):
-        self._max_sampling_cnt = self._samplint_cnt
-        self._samplint_cnt = 0
-
-
 class MemStatsCollector:
     """
     A Memory statistic collector.
@@ -44,7 +19,6 @@ class MemStatsCollector:
     """
 
     def __init__(self) -> None:
-        self._sampling_cnter = SamplingCounter()
         self._mem_monitor = AsyncMemoryMonitor()
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -57,6 +31,7 @@ class MemStatsCollector:
         self._sampling_time = []
 
         self._start_flag = False
+        self._period_idx = 0
 
     def overall_mem_stats(self, device_type: str):
         if device_type == 'cuda':
@@ -106,15 +81,22 @@ class MemStatsCollector:
         else:
             raise TypeError
 
-    def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of the current sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
-
-    def next_non_model_data(self, device_type: str):
-        """get the non model data of the next sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+    def max_non_model_data(self, device_type: str) -> int:
+        """Get max non model data memory usage of current sampling period
+
+        Args:
+            device_type (str): device type, can be 'cpu' or 'cuda'.
+
+        Returns:
+            int: max non model data memory usage of current sampling period
+        """
+        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
+        assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
+        next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
+        current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
+        next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
+        self._period_idx = next_period_idx
+        return max(current_non_model_data, next_non_model_data)
 
     @property
     def sampling_time(self):
@@ -126,6 +108,7 @@ class MemStatsCollector:
 
     def finish_collection(self):
         self._start_flag = False
+        self._mem_monitor.finish()
 
     def sample_memstats(self) -> None:
         """
@@ -134,8 +117,6 @@ class MemStatsCollector:
         Advance the sampling cnter.
         """
         if self._start_flag:
-            sampling_cnt = self._sampling_cnter.current()
-            assert sampling_cnt == len(self._overall_cuda_list)
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
             self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
@@ -146,13 +127,6 @@ class MemStatsCollector:
             self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
-        # TODO(ver217): refactor sampler
-        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
-        self._sampling_cnter.advance()
-
-    def reset_sampling_cnter(self) -> None:
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()
 
     def clear(self) -> None:
         self._model_data_cuda_list = []
@@ -162,5 +136,4 @@ class MemStatsCollector:
         self._overall_cpu_list = []
 
         self._start_flag = False
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()
+        self._period_idx = 0
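Taken together, the memstats_collector changes replace the removed SamplingCounter with a single `_period_idx` that only moves when statistics are queried: each `sample_memstats()` call during the warm-up collection records one period, and afterwards every `max_non_model_data()` call returns the larger of the current and the following recorded period before advancing the index. A minimal lifecycle sketch, assuming a CUDA device and that the monitor/tracer dependencies are available; the tensors and the Linear layer are hypothetical stand-ins for real training steps.

import torch
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector

collector = MemStatsCollector()

# warm-up iteration: record one memory sample per step of interest
collector.start_collection()
x = torch.randn(16, 1024, device='cuda')
collector.sample_memstats()                      # records period 0
y = torch.nn.Linear(1024, 1024).cuda()(x)
collector.sample_memstats()                      # records period 1
collector.finish_collection()

# later iterations replay the same steps; each call returns
# max(non_model_data[i], non_model_data[i + 1]) and advances _period_idx
budget_step0 = collector.max_non_model_data('cuda')
budget_step1 = collector.max_non_model_data('cuda')
print(budget_step0, budget_step1)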
@@ -1,16 +0,0 @@
-from async_memtracer import AsyncMemoryMonitor
-import torch
-
-if __name__ == '__main__':
-    async_mem_monitor = AsyncMemoryMonitor()
-    input = torch.randn(2, 20).cuda()
-    OP1 = torch.nn.Linear(20, 30).cuda()
-    OP2 = torch.nn.Linear(30, 40).cuda()
-
-    async_mem_monitor.start()
-    output = OP1(input)
-    async_mem_monitor.finish()
-    async_mem_monitor.start()
-    output = OP2(output)
-    async_mem_monitor.finish()
-    async_mem_monitor.save('log.pkl')
@@ -1,37 +0,0 @@
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-import torch
-
-
-def test_mem_collector():
-    collector = MemStatsCollector()
-
-    collector.start_collection()
-
-    a = torch.randn(10).cuda()
-
-    # sampling at time 0
-    collector.sample_memstats()
-
-    m_a = torch.randn(10).cuda()
-    b = torch.randn(10).cuda()
-
-    # sampling at time 1
-    collector.sample_memstats()
-
-    a = b
-
-    # sampling at time 2
-    collector.sample_memstats()
-
-    collector.finish_collection()
-    collector.reset_sampling_cnter()
-
-    # do nothing after collection, just advance sampling cnter
-    collector.sample_memstats()
-    collector.sample_memstats()
-
-    print(collector.overall_mem_stats('cuda'))
-
-
-if __name__ == '__main__':
-    test_mem_collector()