[Gemini] chunk init using runtime visited param order (#2115)

Author: Jiarui Fang
Date: 2022-12-12 18:06:16 +08:00
Committed by: GitHub
Parent: e7d3afc9cc
Commit: 9214d1fe28
10 changed files with 77 additions and 29 deletions

View File

@@ -8,6 +8,7 @@ import torch.distributed as dist
 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
         cpu_offload = self.gemini_manager.policy_name != 'cuda'
-        # TODO: get param order and filter unused params
-        for p in module.parameters():
+        if self.gemini_manager._premade_memstats_:
+            # build chunks in the order params were visited at runtime
+            param_order = self.gemini_manager.memstats()._param_runtime_order
+        else:
+            # build chunks in param initialization order
+            # Note: in this case unused params cannot be filtered out at runtime
+            param_order = OrderedParamGenerator()
+            for p in module.parameters():
+                param_order.append(p)
+        for p in param_order.generate():
             assert isinstance(p, ColoParameter)
             if getattr(p, '_ddp_to_ignore', False):
@@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
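
The diff above only relies on two calls on the param-order object: append(p) to record a parameter and generate() to iterate the recorded parameters. Below is a minimal sketch of that contract, assuming first-visit ordering with de-duplication of shared parameters; it is an illustration only, not the actual colossalai.gemini.memory_tracer.OrderedParamGenerator, and the name ParamOrderSketch is made up for this note.

from typing import Iterator, List, Set

import torch


class ParamOrderSketch:
    """Illustrative stand-in for the param-order contract used by ZeroDDP above."""

    def __init__(self) -> None:
        self._params: List[torch.nn.Parameter] = []
        self._seen: Set[int] = set()

    def append(self, p: torch.nn.Parameter) -> None:
        # Record each parameter once, in the order it is first visited,
        # so shared/tied parameters do not land in two chunks.
        if id(p) not in self._seen:
            self._seen.add(id(p))
            self._params.append(p)

    def generate(self) -> Iterator[torch.nn.Parameter]:
        # Yield parameters back in the recorded order.
        yield from self._params

With pre-made memstats, the recorded order is the order in which parameters were touched during a traced run, so parameters used together tend to land in the same chunk; without memstats, the order falls back to module.parameters() and, as the comment in the diff notes, unused parameters cannot be filtered out.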

View File

@@ -4,6 +4,7 @@ import torch
 from colossalai.gemini.chunk import init_chunk_manager
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import MemStats
 from .data_parallel import ZeroDDP
@@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: Optional[float] = None) -> None:
+                 min_chunk_size_mb: Optional[float] = None,
+                 memstats: Optional[MemStats] = None) -> None:
         """
         A torch.nn.Module wrapper using ZeRO-DP and Gemini.
         ZeRO is for parallelism; Gemini is for memory management.
@@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
             min_chunk_size_mb (float, optional): the minimum chunk size in megabytes.
                 If the aggregate size of parameters is still smaller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
+            memstats (MemStats, optional): memory statistics collected by a runtime memory tracer.
         """
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
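
A hedged usage sketch of the new memstats argument follows. Only the memstats= keyword (and the MemStats import) comes from this diff; the launch call, the toy model, and the way the statistics are obtained are assumptions about the surrounding workflow rather than part of this change.

import torch

import colossalai
from colossalai.gemini.memory_tracer import MemStats
from colossalai.nn.parallel import GeminiDDP

# Assumes the process was started with torchrun and a distributed backend is available.
colossalai.launch_from_torch(config={})

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))

# In practice this object would be filled in by a runtime memory tracer during a
# warm-up forward/backward pass; an empty MemStats here is only a placeholder.
collected_memstats = MemStats()

# With memstats supplied, ZeroDDP builds chunks in the order parameters were visited
# at runtime; if it is omitted, chunks are built in module.parameters() order as before.
model = GeminiDDP(model,
                  device=torch.device('cuda'),
                  placement_policy='auto',
                  pin_memory=True,
                  memstats=collected_memstats)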