diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index ca3165a71..04b660060 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -133,9 +133,9 @@ class GeminiManager:
         if self._mem_stats_collector:
             self._mem_stats_collector.sample_overall_data()
 
-    def sample_model_data(self):
+    def record_model_data_volume(self):
         if self._mem_stats_collector:
-            self._mem_stats_collector.sample_model_data()
+            self._mem_stats_collector.record_model_data_volume()
 
     @property
     def chunk_manager(self):
diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py
index 6c681d31f..33c0d99c8 100644
--- a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py
+++ b/colossalai/gemini/memory_tracer/chunk_memstats_collector.py
@@ -15,7 +15,7 @@ class ChunkMemStatsCollector(MemStatsCollector):
         self._chunk_manager = chunk_manager
 
     # override
-    def sample_model_data(self) -> None:
+    def record_model_data_volume(self) -> None:
         """Sampling model data statistics.
         """
         if self._start_flag and not self.use_outside_memstats:
diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py
index bc215ccb9..9a1d4cc86 100644
--- a/colossalai/gemini/memory_tracer/memory_stats.py
+++ b/colossalai/gemini/memory_tracer/memory_stats.py
@@ -15,7 +15,7 @@ class MemStats(object):
         self._step_param_dict = dict()
         # (param, List[preop_step])
         self._param_step_dict = dict()
-        # (preop_step, non_model_data)
+        # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1)
         self._step_nmd_dict = dict()
         self._param_runtime_order = OrderedParamGenerator()
 
@@ -23,9 +23,8 @@ class MemStats(object):
         self._prev_overall_cuda = -1
         self._prev_md_cuda = -1
 
-        # old version
-        self.param_non_model_data_map: Dict(Any, List[int]) = {}
+        # old version
 
         self._model_data_cuda_list = []
         self._model_data_cpu_list = []
 
@@ -35,9 +34,12 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
         self._non_model_data_cpu_list = []
 
-    def record_max_cuda_non_model_data(self):
+    def calc_max_cuda_non_model_data(self):
         if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
-            self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda
+            max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
+            self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
+            # compatibility of the old version.
+            self._non_model_data_cuda_list.append(max_cuda_non_model_data)
 
     def record_max_cuda_model_data(self, val):
         self._prev_md_cuda = val
@@ -45,12 +47,45 @@ class MemStats(object):
     def record_max_cuda_overall_data(self, val):
         self._prev_overall_cuda = val
 
+    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
+        """
+        the time step is increased. param list is used between current and the next
+        time step.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): a list of torch paramters.
+ """ + for p in param_list: + if p not in self._param_step_dict: + self._param_step_dict[p] = [self._preop_step] + else: + self._param_step_dict[p].append(self._preop_step) + self._param_runtime_order.append(p) + self._step_param_dict[self._preop_step] = param_list + self._preop_step += 1 + + def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]: + """param_used_step + get the timestep list using the param + + Args: + param (torch.nn.Parameter): a torch param + + Returns: + Optional[List[int]]: a list of int indicates the time step of preop hook. + """ + if param not in self._param_step_dict: + return None + else: + return self._param_step_dict[param] + def param_order(self): if self._param_runtime_order.is_empty(): raise RuntimeError else: return self._param_runtime_order + ## APIs to be depracated def append_overall_data(self, device_type: str, val: float): if device_type == 'cuda': self._overall_cuda_list.append(val) @@ -135,38 +170,6 @@ class MemStats(object): else: raise TypeError - def increase_preop_step(self, param_list: List[torch.nn.Parameter]): - """ - the time step is increased. param list is used between current and the next - time step. - - Args: - param_list (List[torch.nn.Parameter]): a list of torch paramters. - """ - for p in param_list: - if p not in self._param_step_dict: - self._param_step_dict[p] = [self._preop_step] - else: - self._param_step_dict[p].append(self._preop_step) - self._param_runtime_order.append(p) - self._step_param_dict[self._preop_step] = param_list - self._preop_step += 1 - - def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]: - """param_used_timestep - get the timestep list using the param - - Args: - param (torch.nn.Parameter): a torch param - - Returns: - Optional[List[int]]: a list of int indicates the time step of preop hook. - """ - if param not in self._param_step_dict: - return None - else: - return self._param_step_dict[param] - def clear(self): self._model_data_cuda_list = [] self._overall_cuda_list = [] diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py index a81961227..4db03444f 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/gemini/memory_tracer/memstats_collector.py @@ -69,7 +69,7 @@ class MemStatsCollector: self._start_flag = False self._mem_monitor.finish() - def sample_model_data(self) -> None: + def record_model_data_volume(self) -> None: """Sampling model data statistics. 
""" if self._start_flag and not self.use_outside_memstats: diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py index 4cee5dd60..a643751da 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -82,7 +82,9 @@ class RuntimeMemTracer(): def _post_backward(self): cuda_volume = self.param_op_hook.mem_monitor.finish() - self._memstats.append_non_model_data('cuda', cuda_volume - self._memstats.last_model_data('cuda')) + self._memstats.record_max_cuda_overall_data(cuda_volume) + # calc the last Op non model data + self._memstats.calc_max_cuda_non_model_data() self.grad_hook.remove_grad_hook() self._restore_params() diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py index 1ff259762..6d0df4e61 100644 --- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py @@ -86,7 +86,7 @@ class ParamMemTracerHook(ColoParamOpHook): elif cur_dev == "cuda": alloc_storage(p.data) - def sample_model_data(self, params): + def record_model_data_volume(self, params): """ get cuda model data used by params """ @@ -100,21 +100,19 @@ class ParamMemTracerHook(ColoParamOpHook): if not self._grad_stats.unreleased_grad_flag[p]: self._grad_stats.unreleased_grad_volume += cur_model_data_volume self._grad_stats.unreleased_grad_flag[p] = True - self._memstats.append_model_data('cuda', data_volume) # record max non model data used for this Op self._memstats.record_max_cuda_model_data(data_volume) def pre_op(self, params): - # get overall cuda data. - max_cuda_vol_of_period = self.mem_monitor.finish() - # record max cuda overall data for prev Op. - self._memstats.record_max_cuda_overall_data(max_cuda_vol_of_period) - self._memstats.record_max_cuda_non_model_data() - max_cuda_model_data_val = self._memstats.last_model_data('cuda') - if max_cuda_model_data_val is not None: - self._memstats.append_non_model_data('cuda', max_cuda_vol_of_period - max_cuda_model_data_val) + max_cuda_used_pre_op = self.mem_monitor.finish() + # record max cuda overall data for prev OP. + self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op) + # record max cuda non model data for prev OP. 
+        self._memstats.calc_max_cuda_non_model_data()
+
         self._allocate_params_on_cuda(params)
-        self.sample_model_data(params)
+        # record max cuda model data for current OP
+        self.record_model_data_volume(params)
         self.mem_monitor.start()
         self._memstats.increase_preop_step(params)
 
diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/utils/gemini_hook.py
index 99ca38495..5f34410a8 100644
--- a/colossalai/zero/utils/gemini_hook.py
+++ b/colossalai/zero/utils/gemini_hook.py
@@ -32,7 +32,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.adjust_layout(chunks)
         for chunk in chunks:
             self._chunk_manager.access_chunk(chunk)
-        self._gemini_manager.sample_model_data()
+        self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
         params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py
index fa46de146..87bf2c0f5 100644
--- a/colossalai/zero/utils/zero_hook.py
+++ b/colossalai/zero/utils/zero_hook.py
@@ -67,7 +67,7 @@ class ZeroHook(BaseOpHook):
 
         # record model data statistics
         if self._memstarts_collector:
-            self._memstarts_collector.sample_model_data()
+            self._memstarts_collector.record_model_data_volume()
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         self.adjust_module_data(module)
diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py
index 82439144b..3b1ce21c0 100644
--- a/tests/test_gemini/update/test_gemini_use_rmt.py
+++ b/tests/test_gemini/update/test_gemini_use_rmt.py
@@ -47,7 +47,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
     print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
     print('runtime tracer: ', runtime_tracer_non_model_data)
-    print([memstats.param_used_timestep(p) for p in model.parameters()])
+    print([memstats.param_used_step(p) for p in model.parameters()])
+
+    if model_name == 'repeated_computed_layers':
+        for idx, p in enumerate(model.parameters()):
+            step_list = memstats.param_used_step(p)
+            if idx < 4:
+                assert len(step_list) == 4
 
     if model_name == 'repeated_computed_layers':
         for idx, p in enumerate(model.parameters()):
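
Usage sketch (not part of the patch): the snippet below illustrates how the renamed and relocated MemStats APIs from colossalai/gemini/memory_tracer/memory_stats.py are meant to interact across two pre-op steps, mirroring the ParamMemTracerHook.pre_op flow above. The import path, the no-argument MemStats() constructor, and the byte counts are assumptions made for illustration only.

import torch

from colossalai.gemini.memory_tracer.memory_stats import MemStats  # assumed import path

memstats = MemStats()
params = [torch.nn.Parameter(torch.randn(4, 4))]

# pre-op step 0: record the overall CUDA peak of the previous window first.
memstats.record_max_cuda_overall_data(0)
memstats.calc_max_cuda_non_model_data()      # no-op: model data volume still unset (-1)
memstats.record_max_cuda_model_data(1024)    # model data volume used by op 0 (made-up value)
memstats.increase_preop_step(params)         # preop_step 0 -> 1

# pre-op step 1: the overall peak measured since op 0 minus op 0's model data
# is stored as op 0's non model data and appended to the legacy cuda list.
memstats.record_max_cuda_overall_data(4096)
memstats.calc_max_cuda_non_model_data()      # _step_nmd_dict[0] = 4096 - 1024 = 3072
memstats.record_max_cuda_model_data(1024)
memstats.increase_preop_step(params)         # preop_step 1 -> 2

print(memstats.param_used_step(params[0]))       # [0, 1]
print(memstats._non_model_data_cuda_list)        # [3072]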