mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-29 04:05:35 +00:00

[zero] refactor memstats collector (#706)

* refactor memstats collector
* fix disposable
* polish code

This commit is contained in:
parent 3fc8a204dc
commit ab8c6b4a0e
@@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                      ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
                      is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                      param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param)
+                     sync_model_param, disposable)
 from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory_utils.memory_monitor import report_memory_usage
@@ -19,5 +19,5 @@ __all__ = [
     'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
     'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
     'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists'
+    'ensure_path_exists', 'disposable'
 ]

@@ -4,8 +4,8 @@ import os
 import random
 import socket
 from pathlib import Path
-from typing import List, Union
+from typing import Callable, List, Union
+import functools
 import torch
 from torch._six import inf
 from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):


 class model_branch_context(object):

     def __enter__(self):
         self.env_status = env.save()

@@ -131,7 +132,7 @@ def _calc_l2_norm(grads):
             colossal_C.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
             False  # no per-parameter norm
         )
     return norm

@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
         yield
     finally:
         gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
+def disposable(func: Callable) -> Callable:
+    executed = False
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal executed
+        if not executed:
+            executed = True
+            return func(*args, **kwargs)
+
+    return wrapper

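For illustration, a minimal standalone sketch of how the disposable helper added above behaves: the wrapped callable runs on its first invocation only and becomes a no-op afterwards. The warmup function below is hypothetical and not part of the commit.

import functools
from typing import Callable


def disposable(func: Callable) -> Callable:
    # same one-shot wrapper as the helper introduced in this commit
    executed = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal executed
        if not executed:
            executed = True
            return func(*args, **kwargs)

    return wrapper


@disposable
def warmup():
    # hypothetical example function, only for demonstration
    print('warmup done')


warmup()   # prints 'warmup done'
warmup()   # every later call is a no-op and returns None
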
@@ -1,36 +1,11 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List


-class SamplingCounter:
-
-    def __init__(self) -> None:
-        self._samplint_cnt = 0
-        self._max_sampling_cnt = None
-
-    def advance(self):
-        self._samplint_cnt += 1
-
-    def next(self):
-        assert self._max_sampling_cnt is not None
-        return (self._samplint_cnt + 1) % self._max_sampling_cnt
-
-    def current(self):
-        return self._samplint_cnt
-
-    def max(self):
-        return self._max_sampling_cnt
-
-    def reset(self):
-        self._max_sampling_cnt = self._samplint_cnt
-        self._samplint_cnt = 0
-
-
 class MemStatsCollector:
     """
     A Memory statistic collector.
@@ -44,7 +19,6 @@ class MemStatsCollector:
     """

     def __init__(self) -> None:
-        self._sampling_cnter = SamplingCounter()
         self._mem_monitor = AsyncMemoryMonitor()
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -57,6 +31,7 @@ class MemStatsCollector:
         self._sampling_time = []

         self._start_flag = False
+        self._period_idx = 0

     def overall_mem_stats(self, device_type: str):
         if device_type == 'cuda':
@@ -106,15 +81,22 @@ class MemStatsCollector:
         else:
             raise TypeError

-    def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of the current sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
-
-    def next_non_model_data(self, device_type: str):
-        """get the non model data of the next sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+    def max_non_model_data(self, device_type: str) -> int:
+        """Get max non model data memory usage of current sampling period
+
+        Args:
+            device_type (str): device type, can be 'cpu' or 'cuda'.
+
+        Returns:
+            int: max non model data memory usage of current sampling period
+        """
+        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
+        assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
+        next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
+        current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
+        next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
+        self._period_idx = next_period_idx
+        return max(current_non_model_data, next_non_model_data)

     @property
     def sampling_time(self):
@@ -126,6 +108,7 @@ class MemStatsCollector:

     def finish_collection(self):
         self._start_flag = False
+        self._mem_monitor.finish()

     def sample_memstats(self) -> None:
         """
@@ -134,8 +117,6 @@ class MemStatsCollector:
         Advance the sampling cnter.
         """
         if self._start_flag:
-            sampling_cnt = self._sampling_cnter.current()
-            assert sampling_cnt == len(self._overall_cuda_list)
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
             self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
@@ -146,13 +127,6 @@ class MemStatsCollector:
             self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
-        # TODO(ver217): refactor sampler
-        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
-        self._sampling_cnter.advance()
-
-    def reset_sampling_cnter(self) -> None:
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()

     def clear(self) -> None:
         self._model_data_cuda_list = []
@@ -162,5 +136,4 @@ class MemStatsCollector:
         self._overall_cpu_list = []

         self._start_flag = False
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()
+        self._period_idx = 0

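A minimal sketch of the period-cycling logic that replaces SamplingCounter, written as a standalone snippet with made-up numbers rather than the library class: max_non_model_data compares the non model data of the current and the next recorded sampling period, returns the larger value, and advances the period index, wrapping around at the end of the warmup trace.

# made-up per-period non model data measurements, in bytes
non_model_data_list = [100, 250, 180]
period_idx = 0


def max_non_model_data():
    global period_idx
    next_period_idx = (period_idx + 1) % len(non_model_data_list)
    result = max(non_model_data_list[period_idx], non_model_data_list[next_period_idx])
    period_idx = next_period_idx   # advance to the next period for the following query
    return result


print(max_non_model_data())   # max(100, 250) -> 250
print(max_non_model_data())   # max(250, 180) -> 250
print(max_non_model_data())   # max(180, 100) -> 180, index wraps around
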
@@ -1,16 +0,0 @@
-from async_memtracer import AsyncMemoryMonitor
-import torch
-
-if __name__ == '__main__':
-    async_mem_monitor = AsyncMemoryMonitor()
-    input = torch.randn(2, 20).cuda()
-    OP1 = torch.nn.Linear(20, 30).cuda()
-    OP2 = torch.nn.Linear(30, 40).cuda()
-
-    async_mem_monitor.start()
-    output = OP1(input)
-    async_mem_monitor.finish()
-    async_mem_monitor.start()
-    output = OP2(output)
-    async_mem_monitor.finish()
-    async_mem_monitor.save('log.pkl')

@@ -1,37 +0,0 @@
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-import torch
-
-
-def test_mem_collector():
-    collector = MemStatsCollector()
-
-    collector.start_collection()
-
-    a = torch.randn(10).cuda()
-
-    # sampling at time 0
-    collector.sample_memstats()
-
-    m_a = torch.randn(10).cuda()
-    b = torch.randn(10).cuda()
-
-    # sampling at time 1
-    collector.sample_memstats()
-
-    a = b
-
-    # sampling at time 2
-    collector.sample_memstats()
-
-    collector.finish_collection()
-    collector.reset_sampling_cnter()
-
-    # do nothing after collection, just advance sampling cnter
-    collector.sample_memstats()
-    collector.sample_memstats()
-
-    print(collector.overall_mem_stats('cuda'))
-
-
-if __name__ == '__main__':
-    test_mem_collector()

@@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
             max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
-                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')

         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

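To make the budget arithmetic above concrete, a worked example with made-up values; the variable names mirror the diff, the numbers are purely illustrative.

cuda_capacity = 16 * 1024**3                       # 16 GiB device, illustrative value
max_cuda_non_model_data_per_period = 4 * 1024**3   # e.g. returned by max_non_model_data('cuda')
used_cuda_model_data = 6 * 1024**3                 # model data already resident on the GPU

total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period   # 12 GiB budget for model data
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data         # 6 GiB still available
print(avail_cuda_model_data / 1024**3)   # 6.0
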
@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
                 for param in submodule.parameters(recurse=False):
                     if hasattr(param, 'colo_attr'):
                         self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
+            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
             self._stateful_tensor_mgr = None
-        self._iter_cnter = 0

         # Register hooks
         self._ophook_list = [
@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
                 f.write('\n')

     def _pre_forward_operations(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            # the operation will affect the memory tracer behavior in ZeroHook
-            self._memstats_collector.start_collection()
+        # the operation will affect the memory tracer behavior in ZeroHook
+        if self._memstats_collector:
+            self._start_collect_memstats()

         for p in self.module.parameters():
             if hasattr(p, 'colo_attr'):
@@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
             ophook.post_iter()

     def _update_memstats(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            self._memstats_collector.finish_collection()
         if self._memstats_collector:
-            self._memstats_collector.reset_sampling_cnter()
+            self._finish_collect_memstats()
             # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
             # the way to calculate margin space is based on the assumption that
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = colo_cuda_memory_capacity() - max(
                 self._memstats_collector.overall_mem_stats('cuda'))
-        self._iter_cnter += 1

     @torch.no_grad()
     def _post_backward_operations(self) -> None:

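A sketch of the pattern ShardedModelV2 now relies on, using a simplified stand-in collector rather than the real MemStatsCollector: wrapping start_collection and finish_collection with disposable makes them fire only during the first training iteration, which is what the removed _iter_cnter bookkeeping used to guarantee. It assumes a ColossalAI build that contains the disposable helper added earlier in this commit.

from colossalai.utils import disposable   # assumes this commit's colossalai.utils


class DummyCollector:
    # simplified stand-in for MemStatsCollector
    def start_collection(self):
        print('start collection')

    def finish_collection(self):
        print('finish collection')


collector = DummyCollector()
start_collect_memstats = disposable(collector.start_collection)
finish_collect_memstats = disposable(collector.finish_collection)

for it in range(3):                # three training iterations
    start_collect_memstats()       # prints only on iteration 0
    # ... forward and backward pass would run here ...
    finish_collect_memstats()      # prints only on iteration 0
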
@@ -55,7 +55,6 @@ def run_stm():
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
     mem_collector.sample_memstats()
     mem_collector.finish_collection()
-    mem_collector.reset_sampling_cnter()
     stateful_tensor_mgr.reset()

     # warmup done