[zero] refactor memstats collector (#706)

* refactor memstats collector

* fix disposable

* polish code
Author: ver217
Date: 2022-04-11 10:46:08 +08:00
Committed by: GitHub
Parent: 3fc8a204dc
Commit: ab8c6b4a0e
8 changed files with 44 additions and 114 deletions


@@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
             max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
-                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
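The refactor folds the old two-call pattern into a single accessor on the collector. Below is a minimal sketch of what `max_non_model_data` could look like, assuming the collector keeps one non-model-data sample per sampling moment and a cursor for the current moment; the real MemStatsCollector may do its bookkeeping differently:

    class MemStatsCollectorSketch:

        def __init__(self):
            # one non-model-data sample (bytes) per sampling moment of the warmup iteration
            self._non_model_data_cuda = []
            # cursor pointing at the current sampling moment
            self._sampling_cnter = 0

        def max_non_model_data(self, device_type: str) -> int:
            # worst case over this sampling moment and the next one, replacing
            # max(current_non_model_data(...), next_non_model_data(...)) at call sites
            assert device_type == 'cuda', 'this sketch only tracks cuda'
            cur = self._sampling_cnter % len(self._non_model_data_cuda)
            nxt = (cur + 1) % len(self._non_model_data_cuda)
            return max(self._non_model_data_cuda[cur], self._non_model_data_cuda[nxt])

    collector = MemStatsCollectorSketch()
    collector._non_model_data_cuda = [512, 2048, 1024]  # bytes, recorded during warmup
    collector._sampling_cnter = 1
    assert collector.max_non_model_data('cuda') == 2048  # max of moments 1 and 2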


@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.engine.gradient_handler.utils import bucket_allreduce
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, disposable
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
                 for param in submodule.parameters(recurse=False):
                     if hasattr(param, 'colo_attr'):
                         self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+            self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
+            self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
             self._stateful_tensor_mgr = None
-        self._iter_cnter = 0
         # Register hooks
         self._ophook_list = [
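The two `disposable(...)` wrappers replace the `_iter_cnter` guard: each wrapped callable runs at most once, so the hooks can fire every iteration while collection starts and finishes only on the first one. A minimal sketch of such a run-once wrapper, assuming `colossalai.utils.disposable` behaves this way; the library's actual implementation may differ:

    import functools
    from typing import Callable

    def disposable(func: Callable) -> Callable:
        # wrap func so it executes at most once; later calls are silent no-ops
        executed = False

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal executed
            if not executed:
                executed = True
                return func(*args, **kwargs)

        return wrapper

    start = disposable(lambda: print('collection started'))
    start()  # runs once
    start()  # no-op on every later call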
@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
                 f.write('\n')
 
     def _pre_forward_operations(self):
-        if self._iter_cnter == 0 and self._memstats_collector:
-            # the operation will affect the memory tracer behavior in ZeroHook
-            self._memstats_collector.start_collection()
+        # the operation will affect the memory tracer behavior in ZeroHook
+        if self._memstats_collector:
+            self._start_collect_memstats()
         for p in self.module.parameters():
             if hasattr(p, 'colo_attr'):
@@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
ophook.post_iter()
def _update_memstats(self):
if self._iter_cnter == 0 and self._memstats_collector:
self._memstats_collector.finish_collection()
if self._memstats_collector:
self._memstats_collector.reset_sampling_cnter()
self._finish_collect_memstats()
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
self._memstats_collector.overall_mem_stats('cuda'))
self._iter_cnter += 1
@torch.no_grad()
def _post_backward_operations(self) -> None:
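The comment block spells out the margin arithmetic: cuda margin space = device capacity - peak fwd/bwd cuda usage, and since model data stays resident on cuda during training, that slack can hold optimizer states (OS). A toy example with made-up numbers; `colo_cuda_memory_capacity` and `overall_mem_stats` are the calls from the diff, while the values below are purely hypothetical:

    GiB = 1024**3

    cuda_capacity = 16 * GiB                 # hypothetical device capacity
    overall_cuda_stats = [6 * GiB, 9 * GiB]  # hypothetical sampled fwd/bwd usage

    # margin = capacity - peak fwd/bwd usage; model data is assumed to stay
    # resident on cuda, so this slack can store optimizer states (OS)
    cuda_margin_space = cuda_capacity - max(overall_cuda_stats)
    assert cuda_margin_space == 7 * GiB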