Mirror of https://github.com/hpcaitech/ColossalAI.git
[Gemini] chunk init using runtime visited param order (#2115)
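This commit lets GeminiDDP build its chunks in the order parameters are actually visited at runtime, provided a pre-collected MemStats object is passed in. The usage sketch below is illustrative only and not part of the commit: the GeminiDDP keyword arguments mirror the signature visible in the diff, while collect_memstats is a hypothetical helper standing in for a runtime memory tracer, and the ColoInitContext import path is an assumption about this era of the codebase.

# Usage sketch only, not part of this commit. collect_memstats is a
# hypothetical placeholder for however a runtime memory tracer produces
# a MemStats object during a traced warm-up iteration.
import torch
import torch.nn as nn

from colossalai.gemini.memory_tracer import MemStats
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils.model.colo_init_context import ColoInitContext

# Parameters must be ColoParameters, so the model is built under ColoInitContext.
with ColoInitContext(device=torch.device('cpu')):
    model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 128))

memstats: MemStats = collect_memstats(model)  # hypothetical helper

# With memstats given, chunks follow the runtime-visited parameter order;
# without it, ZeroDDP falls back to the module.parameters() order.
ddp_model = GeminiDDP(model,
                      device=torch.device('cuda'),
                      placement_policy='auto',
                      pin_memory=True,
                      memstats=memstats)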
@@ -8,6 +8,7 @@ import torch.distributed as dist
 from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import OrderedParamGenerator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -216,8 +217,18 @@ class ZeroDDP(ColoDDP):
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
 
         cpu_offload = self.gemini_manager.policy_name != 'cuda'
         # TODO: get param order and filter unused params
-        for p in module.parameters():
+        if self.gemini_manager._premade_memstats_:
+            # build chunks in param runtime visited order.
+            param_order = self.gemini_manager.memstats()._param_runtime_order
+        else:
+            # build chunks in param initialization order.
+            # Note: this way, unused params cannot be filtered out at runtime.
+            param_order = OrderedParamGenerator()
+            for p in module.parameters():
+                param_order.append(p)
+
+        for p in param_order.generate():
             assert isinstance(p, ColoParameter)
 
             if getattr(p, '_ddp_to_ignore', False):
@@ -243,7 +254,7 @@ class ZeroDDP(ColoDDP):
         self.chunk_manager.close_all_groups()
         self._cast_buffers()
 
-        params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
+        params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
         for p, fp32_p in zip(params_list, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
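To make the new ordering logic concrete, here is a simplified, self-contained illustration of the contract the hunk above relies on. _OrderedParamCollector and build_param_order are illustrative names only; the real class is colossalai.gemini.memory_tracer.OrderedParamGenerator, of which just the append()/generate() behaviour used above is mirrored here.

# Simplified illustration only; not the library code.
from typing import Iterator, List, Optional

import torch.nn as nn


class _OrderedParamCollector:
    """Toy stand-in: records parameters and replays them in first-visit order."""

    def __init__(self) -> None:
        self._visited: List[nn.Parameter] = []

    def append(self, p: nn.Parameter) -> None:
        self._visited.append(p)

    def generate(self) -> Iterator[nn.Parameter]:
        seen = set()
        for p in self._visited:
            # A shared parameter may be visited more than once; yield it once.
            if id(p) not in seen:
                seen.add(id(p))
                yield p


def build_param_order(module: nn.Module,
                      runtime_order: Optional[_OrderedParamCollector] = None) -> _OrderedParamCollector:
    """Mirrors the branch added to ZeroDDP.__init__: prefer the runtime visited
    order when pre-made memory statistics exist, otherwise fall back to the
    parameter initialization order."""
    if runtime_order is not None:    # analogous to checking _premade_memstats_
        return runtime_order
    order = _OrderedParamCollector()
    for p in module.parameters():
        order.append(p)
    return order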
@@ -4,6 +4,7 @@ import torch
 
 from colossalai.gemini.chunk import init_chunk_manager
 from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.gemini.memory_tracer import MemStats
 
 from .data_parallel import ZeroDDP
 
@@ -18,7 +19,8 @@ class GeminiDDP(ZeroDDP):
                  force_outputs_fp32: bool = False,
                  search_range_mb: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: Optional[float] = None) -> None:
+                 min_chunk_size_mb: Optional[float] = None,
+                 memstats: Optional[MemStats] = None) -> None:
         """
         A torch.nn.Module wrapper using ZeRO-DP and Gemini.
         ZeRO is for parallelism. Gemini is for memory management.
@@ -44,11 +46,12 @@ class GeminiDDP(ZeroDDP):
             min_chunk_size_mb (float, optional): the minimum chunk size in megabytes.
                 If the aggregate size of parameters is still smaller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
+            memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
         """
         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
                                            search_range_mb=search_range_mb,
                                            min_chunk_size_mb=min_chunk_size_mb)
-        gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
         super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
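The only other consumer of the new argument is GeminiManager, which ZeroDDP then queries through _premade_memstats_ and memstats(). The sketch below is not the actual GeminiManager implementation; it only illustrates, under that assumption, the small interface the data_parallel.py hunk depends on.

# Interface sketch only, not the actual GeminiManager implementation.
from typing import Any, Optional


class _ManagerWithMemStats:
    """Illustrates the two members ZeroDDP.__init__ reads after this change."""

    def __init__(self, placement_policy: str, chunk_manager: Any,
                 memstats: Optional[Any] = None) -> None:
        self._placement_policy = placement_policy
        self._chunk_manager = chunk_manager
        self._memstats = memstats
        # True only when the caller supplied pre-collected statistics,
        # mirroring the _premade_memstats_ flag checked in ZeroDDP.__init__.
        self._premade_memstats_ = memstats is not None

    def memstats(self) -> Optional[Any]:
        # ZeroDDP reads memstats()._param_runtime_order to obtain the
        # runtime-visited parameter order.
        return self._memstats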