[Gemini] chunk init use OrderedParamGenerator (#2110)

Jiarui Fang 2022-12-11 21:41:13 +08:00 committed by GitHub
parent 63fbba3c19
commit 8afc001f4f
6 changed files with 32 additions and 11 deletions

View File

@@ -2,3 +2,5 @@ from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .manager import ChunkManager
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager
__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']

View File

@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
import numpy as np
import torch.nn as nn
from colossalai.gemini.memory_tracer import OrderedParamGenerator
from colossalai.tensor import ColoParameter
@@ -40,20 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
def classify_params_by_dp_degree(model: nn.Module) -> Dict[int, List[ColoParameter]]:
def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
Args:
model (nn.Module): model
param_order (OrderedParamGenerator): the order in which the parameters are visited at runtime
Returns:
Dict[int, List[ColoParameter]]: a dict containing the classification results.
The keys are dp_degrees and the values are parameters.
"""
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in model.parameters():
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please initialize the model within ColoInitContext"
if not in_ddp(param):
continue
@@ -85,11 +86,15 @@ def search_chunk_configuration(
Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
"""
param_order = OrderedParamGenerator()
for p in model.parameters():
param_order.append(p)
search_range_byte = round(search_range_mb * 1024**2)
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
params_dict = classify_params_by_dp_degree(model)
params_dict = classify_params_by_dp_degree(param_order)
config_dict: Dict[int, Dict] = dict()
size_dict: Dict[int, List[int]] = dict()
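For orientation, a minimal usage sketch of the searcher whose internals change above. Only model, search_range_mb, min_chunk_size_mb, and the Tuple[Dict, int] return value are visible in this diff; the search_interval_byte argument, the ColoInitContext import path, and its device keyword are assumptions and may differ from the real API.

import torch
import torch.nn as nn

from colossalai.gemini.chunk import search_chunk_configuration
from colossalai.utils.model.colo_init_context import ColoInitContext  # assumed import path

# The model must be built inside ColoInitContext so that its parameters are
# ColoParameters (see the assert in classify_params_by_dp_degree above).
with ColoInitContext(device=torch.device('cpu')):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

# search_range_mb and min_chunk_size_mb appear in the diff above;
# search_interval_byte is an assumed extra argument.
config_dict, wasted_bytes = search_chunk_configuration(
    model,
    search_range_mb=32,
    search_interval_byte=1024,
    min_chunk_size_mb=32,
)
print(config_dict, wasted_bytes)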

View File

@@ -1,4 +1,4 @@
from .param_runtime_order import ParamRuntimeOrder # isort:skip
from .param_runtime_order import OrderedParamGenerator # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
@@ -7,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder'
'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
]

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, List
from colossalai.gemini.memory_tracer import ParamRuntimeOrder
from colossalai.gemini.memory_tracer import OrderedParamGenerator
class MemStats(object):
@@ -21,7 +21,7 @@ class MemStats(object):
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
self._param_runtime_order = ParamRuntimeOrder()
self._param_runtime_order = OrderedParamGenerator()
def append_overall_data(self, device_type: str, val: float):
if device_type == 'cuda':
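With the swap above, MemStats carries an OrderedParamGenerator instead of a ParamRuntimeOrder. Below is a hypothetical sketch of how a runtime visiting order could be recorded through it; the no-argument MemStats() constructor and the direct use of the private _param_runtime_order attribute are assumptions made purely for illustration, since the real tracer feeds this through its op hooks.

import torch.nn as nn

from colossalai.gemini.memory_tracer import MemStats

# Illustration only: real code records the order from the memory-tracer hooks,
# not by touching the private attribute.
mem_stats = MemStats()
model = nn.Linear(8, 8)

for p in model.parameters():  # stand-in for the order observed at runtime
    mem_stats._param_runtime_order.append(p)

ordered_params = list(mem_stats._param_runtime_order.generate())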

View File

@@ -1,8 +1,22 @@
from abc import ABC
import torch
class ParamRuntimeOrder(object):
"""ParamRuntimeOrder
class ParamGenerator(ABC):
def append(self, param: torch.nn.Parameter):
pass
def generate(self):
pass
def clear(self):
pass
class OrderedParamGenerator(ParamGenerator):
"""OrderedParamGenerator
Contain the order of parameters visited during runtime.
"""

View File

@@ -1,6 +1,6 @@
import torch.nn
from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager