mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-28 08:17:57 +00:00)

[Gemini] chunk init use OrderedParamGenerator (#2110)

parent 63fbba3c19
commit 8afc001f4f
@@ -2,3 +2,5 @@ from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
 from .manager import ChunkManager
+from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
 from .utils import init_chunk_manager

+__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import torch.nn as nn

+from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.tensor import ColoParameter


@@ -40,20 +41,20 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc


-def classify_params_by_dp_degree(model: nn.Module) -> Dict[int, List[ColoParameter]]:
+def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
     """classify_params_by_dp_degree

     Classify the parameters by their dp degree

     Args:
-        model (nn.Module): model
+        param_order (OrderedParamGenerator): the order in which the parameters are visited

     Returns:
         Dict[int, List[ColoParameter]]: a dict contains the classification results.
             The keys are dp_degrees and the values are parameters.
     """
     params_dict: Dict[int, List[ColoParameter]] = dict()
-    for param in model.parameters():
+    for param in param_order.generate():
         assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
         if not in_ddp(param):
             continue
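
The hunk above switches classify_params_by_dp_degree from walking model.parameters() to consuming an OrderedParamGenerator. A minimal usage sketch of that append()/generate() contract follows; it is illustrative rather than part of the commit, and it uses a plain nn.Linear as a placeholder (real Gemini usage expects ColoParameter objects created under ColoInitContext, as the assert above notes).

    # Illustrative sketch (not part of the commit): the append()/generate() contract
    # that classify_params_by_dp_degree now relies on.
    import torch.nn as nn

    from colossalai.gemini.memory_tracer import OrderedParamGenerator

    model = nn.Linear(4, 4)                 # placeholder; real usage passes a Colo-initialized model
    param_order = OrderedParamGenerator()
    for p in model.parameters():
        param_order.append(p)               # record parameters in visit order

    for param in param_order.generate():    # yields the recorded parameters in that order
        print(param.shape)
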
@@ -85,11 +86,15 @@ def search_chunk_configuration(
         Tuple[Dict, int]: chunk config and its memory chunk waste in byte.
     """

+    param_order = OrderedParamGenerator()
+    for p in model.parameters():
+        param_order.append(p)
+
     search_range_byte = round(search_range_mb * 1024**2)
     min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
     assert search_range_byte >= 0

-    params_dict = classify_params_by_dp_degree(model)
+    params_dict = classify_params_by_dp_degree(param_order)
     config_dict: Dict[int, Dict] = dict()

     size_dict: Dict[int, List[int]] = dict()
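
The hunk above keeps the public entry point unchanged: callers still pass a model, and search_chunk_configuration now builds the OrderedParamGenerator internally before classifying parameters. A hedged call-site sketch follows; the search_range_mb / min_chunk_size_mb names and the (config, waste-in-bytes) return shape come from the hunk, while the launch setup, the ColoInitContext path, and the search_interval_byte argument are assumptions about the library of this era rather than content of this commit.

    # Hedged call-site sketch (assumes a torchrun launch with a GPU and that the model
    # is created under ColoInitContext, so its parameters are ColoParameter objects).
    import colossalai
    import torch.nn as nn

    from colossalai.gemini.chunk import search_chunk_configuration
    from colossalai.utils.model.colo_init_context import ColoInitContext

    colossalai.launch_from_torch(config={})
    with ColoInitContext():
        model = nn.Linear(1024, 1024)       # placeholder model

    config_dict, waste_bytes = search_chunk_configuration(
        model,
        search_range_mb=32,                 # shown in the hunk: converted to bytes internally
        search_interval_byte=1024,          # assumed extra argument (often the hidden size); not shown above
        min_chunk_size_mb=32,               # shown in the hunk: converted to bytes internally
    )
    # config_dict holds the chunk configuration; waste_bytes is the memory chunk waste
    # in bytes, per the Returns docstring above.
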
@@ -1,4 +1,4 @@
-from .param_runtime_order import ParamRuntimeOrder # isort:skip
+from .param_runtime_order import OrderedParamGenerator # isort:skip
 from .memory_stats import MemStats # isort:skip
 from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
 from .memstats_collector import MemStatsCollector # isort:skip
@@ -7,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip

 __all__ = [
     'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
-    'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder'
+    'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
 ]
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List

-from colossalai.gemini.memory_tracer import ParamRuntimeOrder
+from colossalai.gemini.memory_tracer import OrderedParamGenerator


 class MemStats(object):
@@ -21,7 +21,7 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
         self._non_model_data_cpu_list = []

-        self._param_runtime_order = ParamRuntimeOrder()
+        self._param_runtime_order = OrderedParamGenerator()

     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
@@ -1,8 +1,22 @@
+from abc import ABC
+
 import torch


-class ParamRuntimeOrder(object):
-    """ParamRuntimeOrder
+class ParamGenerator(ABC):
+
+    def append(self, param: torch.nn.Parameter):
+        pass
+
+    def generate(self):
+        pass
+
+    def clear(self):
+        pass
+
+
+class OrderedParamGenerator(ParamGenerator):
+    """OrderedParamGenerator

     Contain the order of parameters visited during runtime.
     """
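
The hunk above introduces the ParamGenerator interface and replaces ParamRuntimeOrder with OrderedParamGenerator, but the method bodies fall outside this excerpt. Below is a hypothetical sketch of one implementation consistent with the interface and docstring (ordered, de-duplicated recording of visited parameters); the class and attribute names in the sketch are illustrative, not taken from the commit.

    # Hypothetical sketch, not the commit's actual body: records each parameter once,
    # in the order it is first appended, and replays that order from generate().
    import torch


    class OrderedParamGeneratorSketch:

        def __init__(self):
            self._visited = set()      # identities of parameters already recorded
            self._param_order = []     # parameters in first-visit order

        def append(self, param: torch.nn.Parameter):
            if id(param) not in self._visited:
                self._visited.add(id(param))
                self._param_order.append(param)

        def generate(self):
            for param in self._param_order:
                yield param

        def clear(self):
            self._visited.clear()
            self._param_order.clear()
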
@@ -1,6 +1,6 @@
 import torch.nn

-from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder
+from colossalai.gemini.memory_tracer import MemStats
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager