mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 19:55:29 +00:00
[zero] refactor memstats collector (#706)
* refactor memstats collector * fix disposable * polish code
This commit is contained in:
parent
3fc8a204dc
commit
ab8c6b4a0e
@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
|
||||
ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
|
||||
is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
|
||||
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
|
||||
sync_model_param)
|
||||
sync_model_param, disposable)
|
||||
from .data_sampler import DataParallelSampler, get_dataloader
|
||||
from .gradient_accumulation import accumulate_gradient
|
||||
from .memory_utils.memory_monitor import report_memory_usage
|
||||
@ -19,5 +19,5 @@ __all__ = [
|
||||
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
|
||||
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
|
||||
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
|
||||
'ensure_path_exists'
|
||||
'ensure_path_exists', 'disposable'
|
||||
]
|
||||
|
@ -4,8 +4,8 @@ import os
|
||||
import random
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from typing import Callable, List, Union
|
||||
import functools
|
||||
import torch
|
||||
from torch._six import inf
|
||||
from torch.nn.parameter import Parameter
|
||||
@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
|
||||
|
||||
|
||||
class model_branch_context(object):
|
||||
|
||||
def __enter__(self):
|
||||
self.env_status = env.save()
|
||||
|
||||
@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
|
||||
colossal_C.multi_tensor_l2norm,
|
||||
dummy_overflow_buf,
|
||||
[grads],
|
||||
False # no per-parameter norm
|
||||
False # no per-parameter norm
|
||||
)
|
||||
return norm
|
||||
|
||||
@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
|
||||
yield
|
||||
finally:
|
||||
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
|
||||
|
||||
|
||||
def disposable(func: Callable) -> Callable:
|
||||
executed = False
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
nonlocal executed
|
||||
if not executed:
|
||||
executed = True
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
@ -1,36 +1,11 @@
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory_utils.utils import colo_device_memory_used
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
|
||||
import torch
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
|
||||
class SamplingCounter:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._samplint_cnt = 0
|
||||
self._max_sampling_cnt = None
|
||||
|
||||
def advance(self):
|
||||
self._samplint_cnt += 1
|
||||
|
||||
def next(self):
|
||||
assert self._max_sampling_cnt is not None
|
||||
return (self._samplint_cnt + 1) % self._max_sampling_cnt
|
||||
|
||||
def current(self):
|
||||
return self._samplint_cnt
|
||||
|
||||
def max(self):
|
||||
return self._max_sampling_cnt
|
||||
|
||||
def reset(self):
|
||||
self._max_sampling_cnt = self._samplint_cnt
|
||||
self._samplint_cnt = 0
|
||||
|
||||
|
||||
class MemStatsCollector:
|
||||
"""
|
||||
A Memory statistic collector.
|
||||
@ -44,7 +19,6 @@ class MemStatsCollector:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sampling_cnter = SamplingCounter()
|
||||
self._mem_monitor = AsyncMemoryMonitor()
|
||||
self._model_data_cuda_list = []
|
||||
self._overall_cuda_list = []
|
||||
@ -57,6 +31,7 @@ class MemStatsCollector:
|
||||
self._sampling_time = []
|
||||
|
||||
self._start_flag = False
|
||||
self._period_idx = 0
|
||||
|
||||
def overall_mem_stats(self, device_type: str):
|
||||
if device_type == 'cuda':
|
||||
@ -106,15 +81,22 @@ class MemStatsCollector:
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def current_non_model_data(self, device_type: str) -> int:
|
||||
"""get the non model data of the current sampling moment
|
||||
"""
|
||||
return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
|
||||
def max_non_model_data(self, device_type: str) -> int:
|
||||
"""Get max non model data memory usage of current sampling period
|
||||
|
||||
def next_non_model_data(self, device_type: str):
|
||||
"""get the non model data of the next sampling moment
|
||||
Args:
|
||||
device_type (str): device type, can be 'cpu' or 'cuda'.
|
||||
|
||||
Returns:
|
||||
int: max non model data memory usage of current sampling period
|
||||
"""
|
||||
return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
|
||||
assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
|
||||
assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
|
||||
next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
|
||||
current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
|
||||
next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
|
||||
self._period_idx = next_period_idx
|
||||
return max(current_non_model_data, next_non_model_data)
|
||||
|
||||
@property
|
||||
def sampling_time(self):
|
||||
@ -126,6 +108,7 @@ class MemStatsCollector:
|
||||
|
||||
def finish_collection(self):
|
||||
self._start_flag = False
|
||||
self._mem_monitor.finish()
|
||||
|
||||
def sample_memstats(self) -> None:
|
||||
"""
|
||||
@ -134,8 +117,6 @@ class MemStatsCollector:
|
||||
Advance the sampling cnter.
|
||||
"""
|
||||
if self._start_flag:
|
||||
sampling_cnt = self._sampling_cnter.current()
|
||||
assert sampling_cnt == len(self._overall_cuda_list)
|
||||
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
|
||||
self._overall_cuda_list.append(self._mem_monitor.finish())
|
||||
self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
|
||||
@ -146,13 +127,6 @@ class MemStatsCollector:
|
||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||
self._sampling_time.append(time.time())
|
||||
self._mem_monitor.start()
|
||||
# TODO(ver217): refactor sampler
|
||||
# print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
|
||||
self._sampling_cnter.advance()
|
||||
|
||||
def reset_sampling_cnter(self) -> None:
|
||||
self._sampling_cnter.reset()
|
||||
self._mem_monitor.finish()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._model_data_cuda_list = []
|
||||
@ -162,5 +136,4 @@ class MemStatsCollector:
|
||||
self._overall_cpu_list = []
|
||||
|
||||
self._start_flag = False
|
||||
self._sampling_cnter.reset()
|
||||
self._mem_monitor.finish()
|
||||
self._period_idx = 0
|
||||
|
@ -1,16 +0,0 @@
|
||||
from async_memtracer import AsyncMemoryMonitor
|
||||
import torch
|
||||
|
||||
if __name__ == '__main__':
|
||||
async_mem_monitor = AsyncMemoryMonitor()
|
||||
input = torch.randn(2, 20).cuda()
|
||||
OP1 = torch.nn.Linear(20, 30).cuda()
|
||||
OP2 = torch.nn.Linear(30, 40).cuda()
|
||||
|
||||
async_mem_monitor.start()
|
||||
output = OP1(input)
|
||||
async_mem_monitor.finish()
|
||||
async_mem_monitor.start()
|
||||
output = OP2(output)
|
||||
async_mem_monitor.finish()
|
||||
async_mem_monitor.save('log.pkl')
|
@ -1,37 +0,0 @@
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
import torch
|
||||
|
||||
|
||||
def test_mem_collector():
|
||||
collector = MemStatsCollector()
|
||||
|
||||
collector.start_collection()
|
||||
|
||||
a = torch.randn(10).cuda()
|
||||
|
||||
# sampling at time 0
|
||||
collector.sample_memstats()
|
||||
|
||||
m_a = torch.randn(10).cuda()
|
||||
b = torch.randn(10).cuda()
|
||||
|
||||
# sampling at time 1
|
||||
collector.sample_memstats()
|
||||
|
||||
a = b
|
||||
|
||||
# sampling at time 2
|
||||
collector.sample_memstats()
|
||||
|
||||
collector.finish_collection()
|
||||
collector.reset_sampling_cnter()
|
||||
|
||||
# do nothing after collection, just advance sampling cnter
|
||||
collector.sample_memstats()
|
||||
collector.sample_memstats()
|
||||
|
||||
print(collector.overall_mem_stats('cuda'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_mem_collector()
|
@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
|
||||
self._mem_stats_collector.next_non_model_data('cuda'))
|
||||
max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
|
||||
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
|
@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
|
||||
from colossalai.engine.paramhooks import BaseParamHookMgr
|
||||
from colossalai.engine.gradient_handler.utils import bucket_allreduce
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
|
||||
for param in submodule.parameters(recurse=False):
|
||||
if hasattr(param, 'colo_attr'):
|
||||
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
|
||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||
else:
|
||||
self._memstats_collector = None
|
||||
self._stateful_tensor_mgr = None
|
||||
self._iter_cnter = 0
|
||||
|
||||
# Register hooks
|
||||
self._ophook_list = [
|
||||
@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
|
||||
f.write('\n')
|
||||
|
||||
def _pre_forward_operations(self):
|
||||
if self._iter_cnter == 0 and self._memstats_collector:
|
||||
# the operation will affect the memory tracer behavior in ZeroHook
|
||||
self._memstats_collector.start_collection()
|
||||
# the operation will affect the memory tracer behavior in ZeroHook
|
||||
if self._memstats_collector:
|
||||
self._start_collect_memstats()
|
||||
|
||||
for p in self.module.parameters():
|
||||
if hasattr(p, 'colo_attr'):
|
||||
@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
|
||||
ophook.post_iter()
|
||||
|
||||
def _update_memstats(self):
|
||||
if self._iter_cnter == 0 and self._memstats_collector:
|
||||
self._memstats_collector.finish_collection()
|
||||
if self._memstats_collector:
|
||||
self._memstats_collector.reset_sampling_cnter()
|
||||
self._finish_collect_memstats()
|
||||
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
|
||||
# the way to calculate margin space is based on the assumption that
|
||||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
|
||||
self._memstats_collector.overall_mem_stats('cuda'))
|
||||
self._iter_cnter += 1
|
||||
|
||||
@torch.no_grad()
|
||||
def _post_backward_operations(self) -> None:
|
||||
|
@ -55,7 +55,6 @@ def run_stm():
|
||||
apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
|
||||
mem_collector.sample_memstats()
|
||||
mem_collector.finish_collection()
|
||||
mem_collector.reset_sampling_cnter()
|
||||
stateful_tensor_mgr.reset()
|
||||
|
||||
# warmup done
|
||||
|
Loading…
Reference in New Issue
Block a user