From 8afc001f4f98bbb38b6527d8c6aa41546d13342c Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Sun, 11 Dec 2022 21:41:13 +0800 Subject: [PATCH] [Gemini] chunk init use OrderedParamGenerator (#2110) --- colossalai/gemini/chunk/__init__.py | 2 ++ colossalai/gemini/chunk/search_utils.py | 13 +++++++++---- colossalai/gemini/memory_tracer/__init__.py | 4 ++-- .../gemini/memory_tracer/memory_stats.py | 4 ++-- .../memory_tracer/param_runtime_order.py | 18 ++++++++++++++++-- .../gemini/memory_tracer/runtime_mem_tracer.py | 2 +- 6 files changed, 32 insertions(+), 11 deletions(-) diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/gemini/chunk/__init__.py index 38117ca3e..6914d2dbe 100644 --- a/colossalai/gemini/chunk/__init__.py +++ b/colossalai/gemini/chunk/__init__.py @@ -2,3 +2,5 @@ from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState from .manager import ChunkManager from .search_utils import classify_params_by_dp_degree, search_chunk_configuration from .utils import init_chunk_manager + +__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index d5cd1329c..f55d87fc2 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple import numpy as np import torch.nn as nn +from colossalai.gemini.memory_tracer import OrderedParamGenerator from colossalai.tensor import ColoParameter @@ -40,20 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: return left + acc -def classify_params_by_dp_degree(model: nn.Module) -> Dict[int, List[ColoParameter]]: +def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]: """classify_params_by_dp_degree Classify the parameters by their dp degree Args: - model (nn.Module): model + param_order 
(OrderedParamGenerator): the order in which params are visited Returns: Dict[int, List[ColoParameter]]: a dict contains the classification results. The keys are dp_degrees and the values are parameters. """ params_dict: Dict[int, List[ColoParameter]] = dict() - for param in model.parameters(): + for param in param_order.generate(): assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" if not in_ddp(param): continue @@ -85,11 +86,15 @@ Tuple[Dict, int]: chunk config and its memory chunk waste in byte. """ + param_order = OrderedParamGenerator() + for p in model.parameters(): + param_order.append(p) + search_range_byte = round(search_range_mb * 1024**2) min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) assert search_range_byte >= 0 - params_dict = classify_params_by_dp_degree(model) + params_dict = classify_params_by_dp_degree(param_order) config_dict: Dict[int, Dict] = dict() size_dict: Dict[int, List[int]] = dict() diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index 12f6b7950..02c9d5754 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -1,4 +1,4 @@ -from .param_runtime_order import ParamRuntimeOrder # isort:skip +from .param_runtime_order import OrderedParamGenerator # isort:skip from .memory_stats import MemStats # isort:skip from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip from .memstats_collector import MemStatsCollector # isort:skip @@ -7,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip __all__ = [ 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', - 'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder' + 'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator' ] diff --git a/colossalai/gemini/memory_tracer/memory_stats.py
b/colossalai/gemini/memory_tracer/memory_stats.py index 4412a580e..a374ab408 100644 --- a/colossalai/gemini/memory_tracer/memory_stats.py +++ b/colossalai/gemini/memory_tracer/memory_stats.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from colossalai.gemini.memory_tracer import ParamRuntimeOrder +from colossalai.gemini.memory_tracer import OrderedParamGenerator class MemStats(object): @@ -21,7 +21,7 @@ class MemStats(object): self._non_model_data_cuda_list = [] self._non_model_data_cpu_list = [] - self._param_runtime_order = ParamRuntimeOrder() + self._param_runtime_order = OrderedParamGenerator() def append_overall_data(self, device_type: str, val: float): if device_type == 'cuda': diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/gemini/memory_tracer/param_runtime_order.py index ceb13bc24..b65251373 100644 --- a/colossalai/gemini/memory_tracer/param_runtime_order.py +++ b/colossalai/gemini/memory_tracer/param_runtime_order.py @@ -1,8 +1,22 @@ +from abc import ABC + import torch -class ParamRuntimeOrder(object): - """ParamRuntimeOrder +class ParamGenerator(ABC): + + def append(self, param: torch.nn.Parameter): + pass + + def generate(self): + pass + + def clear(self): + pass + + +class OrderedParamGenerator(ParamGenerator): + """OrderedParamGenerator Contain the order of parameters visited during runtime. 
""" diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py index 4eacb49d0..4cee5dd60 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,6 +1,6 @@ import torch.nn -from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder +from colossalai.gemini.memory_tracer import MemStats from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.tensor.param_op_hook import ColoParamOpHookManager