mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[zero] refactor memstats collector (#706)
* refactor memstats collector * fix disposable * polish code
This commit is contained in:
@@ -71,8 +71,7 @@ class StatefulTensorMgr(object):
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
|
||||
self._mem_stats_collector.next_non_model_data('cuda'))
|
||||
max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
|
||||
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
|
@@ -12,7 +12,7 @@ from colossalai.engine.ophooks.zero_hook import ZeroHook
|
||||
from colossalai.engine.paramhooks import BaseParamHookMgr
|
||||
from colossalai.engine.gradient_handler.utils import bucket_allreduce
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
@@ -112,10 +112,11 @@ class ShardedModelV2(nn.Module):
|
||||
for param in submodule.parameters(recurse=False):
|
||||
if hasattr(param, 'colo_attr'):
|
||||
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
|
||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||
else:
|
||||
self._memstats_collector = None
|
||||
self._stateful_tensor_mgr = None
|
||||
self._iter_cnter = 0
|
||||
|
||||
# Register hooks
|
||||
self._ophook_list = [
|
||||
@@ -188,9 +189,9 @@ class ShardedModelV2(nn.Module):
|
||||
f.write('\n')
|
||||
|
||||
def _pre_forward_operations(self):
|
||||
if self._iter_cnter == 0 and self._memstats_collector:
|
||||
# the operation will affect the memory tracer behavior in ZeroHook
|
||||
self._memstats_collector.start_collection()
|
||||
# the operation will affect the memory tracer behavior in ZeroHook
|
||||
if self._memstats_collector:
|
||||
self._start_collect_memstats()
|
||||
|
||||
for p in self.module.parameters():
|
||||
if hasattr(p, 'colo_attr'):
|
||||
@@ -221,17 +222,14 @@ class ShardedModelV2(nn.Module):
|
||||
ophook.post_iter()
|
||||
|
||||
def _update_memstats(self):
|
||||
if self._iter_cnter == 0 and self._memstats_collector:
|
||||
self._memstats_collector.finish_collection()
|
||||
if self._memstats_collector:
|
||||
self._memstats_collector.reset_sampling_cnter()
|
||||
self._finish_collect_memstats()
|
||||
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
|
||||
# the way to calculate margin space is based on the assumption that
|
||||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = colo_cuda_memory_capacity() - max(
|
||||
self._memstats_collector.overall_mem_stats('cuda'))
|
||||
self._iter_cnter += 1
|
||||
|
||||
@torch.no_grad()
|
||||
def _post_backward_operations(self) -> None:
|
||||
|
Reference in New Issue
Block a user