[zero] refactor memstats collector (#706)

* refactor memstats collector

* fix disposable

* polish code
ver217 2022-04-11 10:46:08 +08:00 committed by GitHub
parent 3fc8a204dc
commit ab8c6b4a0e
8 changed files with 44 additions and 114 deletions

View File

@@ -5,7 +5,7 @@ from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_paral
                     ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
                     is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
                     param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
-                     sync_model_param)
+                     sync_model_param, disposable)
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory_utils.memory_monitor import report_memory_usage
@@ -19,5 +19,5 @@ __all__ = [
    'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
    'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
    'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists'
+    'ensure_path_exists', 'disposable'
]

View File

@@ -4,8 +4,8 @@ import os
import random
import socket
from pathlib import Path
-from typing import List, Union
+from typing import Callable, List, Union
+import functools
import torch
from torch._six import inf
from torch.nn.parameter import Parameter
@@ -112,6 +112,7 @@ def conditional_context(context_manager, enable=True):
class model_branch_context(object):

    def __enter__(self):
        self.env_status = env.save()
@@ -328,3 +329,16 @@ def switch_virtual_pipeline_parallel_rank(rank):
        yield
    finally:
        gpc.set_virtual_pipeline_parallel_rank(prev_rank)
+
+
+def disposable(func: Callable) -> Callable:
+    executed = False
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        nonlocal executed
+        if not executed:
+            executed = True
+            return func(*args, **kwargs)
+
+    return wrapper
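
For reference, a minimal usage sketch of the new decorator (the announce function below is hypothetical, not part of the patch):

from colossalai.utils import disposable

@disposable
def announce():
    print('runs once')

announce()  # prints 'runs once'
announce()  # does nothing: the wrapped function only ever executes once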

View File

@@ -1,36 +1,11 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
import torch
import time
from typing import List
-
-
-class SamplingCounter:
-
-    def __init__(self) -> None:
-        self._samplint_cnt = 0
-        self._max_sampling_cnt = None
-
-    def advance(self):
-        self._samplint_cnt += 1
-
-    def next(self):
-        assert self._max_sampling_cnt is not None
-        return (self._samplint_cnt + 1) % self._max_sampling_cnt
-
-    def current(self):
-        return self._samplint_cnt
-
-    def max(self):
-        return self._max_sampling_cnt
-
-    def reset(self):
-        self._max_sampling_cnt = self._samplint_cnt
-        self._samplint_cnt = 0
-
-
class MemStatsCollector:
    """
    A Memory statistic collector.
@@ -44,7 +19,6 @@ class MemStatsCollector:
    """

    def __init__(self) -> None:
-        self._sampling_cnter = SamplingCounter()
        self._mem_monitor = AsyncMemoryMonitor()
        self._model_data_cuda_list = []
        self._overall_cuda_list = []
@@ -57,6 +31,7 @@ class MemStatsCollector:
        self._sampling_time = []

        self._start_flag = False
+        self._period_idx = 0

    def overall_mem_stats(self, device_type: str):
        if device_type == 'cuda':
@@ -106,15 +81,22 @@ class MemStatsCollector:
        else:
            raise TypeError

-    def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of the current sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
-
-    def next_non_model_data(self, device_type: str):
-        """get the non model data of the next sampling moment
-        """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+    def max_non_model_data(self, device_type: str) -> int:
+        """Get max non model data memory usage of current sampling period
+
+        Args:
+            device_type (str): device type, can be 'cpu' or 'cuda'.
+
+        Returns:
+            int: max non model data memory usage of current sampling period
+        """
+        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
+        assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
+        next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
+        current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
+        next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
+        self._period_idx = next_period_idx
+        return max(current_non_model_data, next_non_model_data)

    @property
    def sampling_time(self):
@@ -126,6 +108,7 @@
    def finish_collection(self):
        self._start_flag = False
+        self._mem_monitor.finish()

    def sample_memstats(self) -> None:
        """
@@ -134,8 +117,6 @@
        Advance the sampling cnter.
        """
        if self._start_flag:
-            sampling_cnt = self._sampling_cnter.current()
-            assert sampling_cnt == len(self._overall_cuda_list)
            self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
            self._overall_cuda_list.append(self._mem_monitor.finish())
            self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
@@ -146,13 +127,6 @@
            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
        self._sampling_time.append(time.time())
        self._mem_monitor.start()
-        # TODO(ver217): refactor sampler
-        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
-        self._sampling_cnter.advance()
-
-    def reset_sampling_cnter(self) -> None:
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()

    def clear(self) -> None:
        self._model_data_cuda_list = []
@@ -162,5 +136,4 @@ class MemStatsCollector:
        self._overall_cpu_list = []

        self._start_flag = False
-        self._sampling_cnter.reset()
-        self._mem_monitor.finish()
+        self._period_idx = 0
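
For intuition, a standalone sketch of the indexing that max_non_model_data() performs over the samples recorded during the warmup iteration; max_non_model_data_sketch and the sample values are illustrative stand-ins, not part of the patch:

def max_non_model_data_sketch(non_model_data, period_idx):
    # non_model_data stands in for self.non_model_data_list(device_type)
    next_idx = (period_idx + 1) % len(non_model_data)
    peak = max(non_model_data[period_idx], non_model_data[next_idx])
    return peak, next_idx  # the collector also advances its period index this way

samples = [100, 300, 150]                          # hypothetical bytes per sampling moment
peak, idx = max_non_model_data_sketch(samples, 0)  # peak == 300, idx == 1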

View File

@@ -1,16 +0,0 @@
-from async_memtracer import AsyncMemoryMonitor
-import torch
-
-if __name__ == '__main__':
-    async_mem_monitor = AsyncMemoryMonitor()
-    input = torch.randn(2, 20).cuda()
-    OP1 = torch.nn.Linear(20, 30).cuda()
-    OP2 = torch.nn.Linear(30, 40).cuda()
-
-    async_mem_monitor.start()
-    output = OP1(input)
-    async_mem_monitor.finish()
-    async_mem_monitor.start()
-    output = OP2(output)
-    async_mem_monitor.finish()
-    async_mem_monitor.save('log.pkl')

View File

@@ -1,37 +0,0 @@
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-import torch
-
-
-def test_mem_collector():
-    collector = MemStatsCollector()
-
-    collector.start_collection()
-
-    a = torch.randn(10).cuda()
-
-    # sampling at time 0
-    collector.sample_memstats()
-
-    m_a = torch.randn(10).cuda()
-    b = torch.randn(10).cuda()
-
-    # sampling at time 1
-    collector.sample_memstats()
-
-    a = b
-
-    # sampling at time 2
-    collector.sample_memstats()
-
-    collector.finish_collection()
-    collector.reset_sampling_cnter()
-
-    # do nothing after collection, just advance sampling cnter
-    collector.sample_memstats()
-    collector.sample_memstats()
-
-    print(collector.overall_mem_stats('cuda'))
-
-
-if __name__ == '__main__':
-    test_mem_collector()

View File

@@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
-                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
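
A worked example of the budget arithmetic above, with hypothetical numbers:

cuda_capacity = 16 * 1024**3                        # 16 GiB of device memory (hypothetical)
max_cuda_non_model_data_per_period = 2 * 1024**3    # peak non-model data reported by the collector
used_cuda_model_data = 6 * 1024**3                  # model data already resident on CUDA

total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period   # 14 GiB budget for model data
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data         # 8 GiB still available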

View File

@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.engine.gradient_handler.utils import bucket_allreduce
from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, disposable
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import \
    GLOBAL_MODEL_DATA_TRACER
@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
                for param in submodule.parameters(recurse=False):
                    if hasattr(param, 'colo_attr'):
                        self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
+            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
        else:
            self._memstats_collector = None
            self._stateful_tensor_mgr = None
-        self._iter_cnter = 0

        # Register hooks
        self._ophook_list = [
@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
                    f.write('\n')

    def _pre_forward_operations(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            # the operation will affect the memory tracer behavior in ZeroHook
-            self._memstats_collector.start_collection()
+        # the operation will affect the memory tracer behavior in ZeroHook
+        if self._memstats_collector:
+            self._start_collect_memstats()

        for p in self.module.parameters():
            if hasattr(p, 'colo_attr'):
@@ -221,17 +222,14 @@
            ophook.post_iter()

    def _update_memstats(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            self._memstats_collector.finish_collection()
        if self._memstats_collector:
-            self._memstats_collector.reset_sampling_cnter()
+            self._finish_collect_memstats()
            # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
            # the way to calculate margin space is based on the assumption that
            # model data is fixed in cuda during training.
            # cuda margin space can be used to store OS.
            self._cuda_margin_space = colo_cuda_memory_capacity() - max(
                self._memstats_collector.overall_mem_stats('cuda'))
-        self._iter_cnter += 1

    @torch.no_grad()
    def _post_backward_operations(self) -> None:
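
The net effect of dropping _iter_cnter: start_collection() and finish_collection() still run only during the first iteration, because each is wrapped in disposable. A minimal sketch of that behaviour on bound methods (DummyCollector is a hypothetical stand-in for MemStatsCollector, not part of the patch):

from colossalai.utils import disposable

class DummyCollector:
    def start_collection(self):
        print('start collection')

    def finish_collection(self):
        print('finish collection')

collector = DummyCollector()
start_collect = disposable(collector.start_collection)
finish_collect = disposable(collector.finish_collection)

for _ in range(3):     # three training iterations
    start_collect()    # prints only on the first iteration
    finish_collect()   # likewise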

View File

@@ -55,7 +55,6 @@ def run_stm():
    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    mem_collector.finish_collection()
-    mem_collector.reset_sampling_cnter()
    stateful_tensor_mgr.reset()
# warmup done # warmup done