From 25abae6d7ffacd98f871c4fd5fb28b5d745f452b Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 6 Dec 2022 19:48:20 +0800 Subject: [PATCH] [Gemini] use MemStats in Runtime Memory tracer (#2088) --- colossalai/gemini/memory_tracer/__init__.py | 3 ++- colossalai/gemini/memory_tracer/memory_stats.py | 2 ++ .../gemini/memory_tracer/runtime_mem_tracer.py | 15 ++++++++++----- .../gemini/ophooks/runtime_mem_tracer_hook.py | 11 +++++++---- tests/test_gemini/test_runtime_mem_tracer.py | 6 +++--- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index d12461353..5afe6e4ff 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector # isort:skip from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip from .static_memstats_collector import StaticMemStatsCollector # isort:skip +from .memory_stats import MemStats __all__ = [ 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', - 'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER' + 'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemStats' ] diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py index 2bb859683..fcd2ba8d4 100644 --- a/colossalai/gemini/memory_tracer/memory_stats.py +++ b/colossalai/gemini/memory_tracer/memory_stats.py @@ -36,6 +36,8 @@ class MemStats(object): raise TypeError def append_non_model_data(self, device_type: str): + if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0: + return if device_type == 'cuda': self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1]) elif device_type == 'cpu': diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py index 3b16686c7..275a88335 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,5 +1,6 @@ import torch.nn +from colossalai.gemini.memory_tracer import MemStats from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook from colossalai.nn.parallel.data_parallel import _cast_float @@ -24,7 +25,8 @@ class RuntimeMemTracer(): super().__init__() self.module = module self.dtype = dtype - self.param_op_hook = ParamMemTracerHook() + self._memstats = MemStats() + self.param_op_hook = ParamMemTracerHook(self._memstats) self.grad_hook = GradMemTracerHook(module) self.cpu_param_data_dict = {} @@ -74,14 +76,17 @@ class RuntimeMemTracer(): def _post_backward(self): cuda_volume = self.param_op_hook.mem_monitor.finish() - last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1] - GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data) + self._memstats.append_model_data('cuda', cuda_volume) + self._memstats.append_non_model_data('cuda') + # last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1] + # GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data) self.grad_hook.remove_grad_hook() self._restore_params() def _clear_cuda_mem_info(self): - GLOBAL_CUDA_MEM_INFO.model_data_list.clear() - GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear() + # GLOBAL_CUDA_MEM_INFO.model_data_list.clear() + # GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear() + self._memstats.clear() GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear() GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0 diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py index 5d8382ed0..55362f888 100644 --- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py @@ -41,9 +41,10 @@ class GradMemTracerHook(): class ParamMemTracerHook(ColoParamOpHook): - def __init__(self) -> None: + def __init__(self, memstats) -> None: super().__init__() self._training_phase = TrainingPhase.FORWARD + self._memstats = memstats self.mem_monitor = SyncCudaMemoryMonitor() def _free_cuda_params(self, params): @@ -76,12 +77,14 @@ class ParamMemTracerHook(ColoParamOpHook): if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]: GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True - GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume) + # GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume) + self._memstats.append_model_data('cuda', data_volume) def pre_op(self, params): cuda_volume = self.mem_monitor.finish() - if len(GLOBAL_CUDA_MEM_INFO.model_data_list): - GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1]) + self._memstats.append_model_data('cuda', cuda_volume) + # if len(GLOBAL_CUDA_MEM_INFO.model_data_list): + # GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1]) self._allocate_params_on_cuda(params) self.sample_model_data(params) self.mem_monitor.start() diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py index ff55ac54d..34c200e05 100644 --- a/tests/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -3,7 +3,6 @@ from copy import deepcopy import numpy as np import torch -from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.utils.model.colo_init_context import ColoInitContext from tests.components_to_test import run_fwd_bwd @@ -34,9 +33,10 @@ def test_runtime_mem_tracer(): for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2) - cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2 + non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda') + cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2 print("cuda_non_model_data_list", len(cuda_non_model_data_list)) - print(GLOBAL_CUDA_MEM_INFO.non_model_data_list) + print(non_model_data_list) del model