diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 6bf0b4019..3a0ae59fc 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -131,7 +131,7 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
-            max_prefetch=max_prefetch
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 1d734bd83..e6b8cf8ef 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -1,10 +1,9 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Dict, List, Iterable, Tuple
+from typing import List
 
 import torch
-import torch.distributed as dist
 
 from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
@@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
 
-from .chunk import Chunk
-
 
 class TrainingPhase(Enum):
     FORWARD = 0
@@ -23,7 +20,9 @@ class TrainingPhase(Enum):
 logger = DistributedLogger("gemini_hook")
 
 import os
-rank = int(os.environ['RANK'])
+
+rank = int(os.environ["RANK"])
+
 
 class GeminiZeROHook(ColoParamOpHook):
     def __init__(self, gemini_manager: GeminiManager) -> None:
@@ -32,14 +31,13 @@ class GeminiZeROHook(ColoParamOpHook):
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
 
-
     def pre_op(self, params):
         # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)
 
         # wait for prefetched chunks, filter those are not prefetched
-        unique_chunks = set(all_chunks)
+        set(all_chunks)
         chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
 
         # transfer state
@@ -48,7 +46,9 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.sample_overall_data()
 
         # evit chunks, aware of async fetched
-        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0)
+        self._gemini_manager.adjust_layout(
+            all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
+        )
 
         # fetch the rest synchronously
         for chunk in chunks_fetch_sync:
@@ -57,7 +57,9 @@ class GeminiZeROHook(ColoParamOpHook):
         # get possible chunks to prefetch
         chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
         if rank == 0 and not self._gemini_manager.is_warmup():
-            print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}")
+            print(
+                f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}"
+            )
             print(f"{all_chunks=}")
             print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
             print(f"{chunks_fetch_sync=}")
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 6640bf03b..11bde789c 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -1,6 +1,6 @@
 import functools
 from time import time
-from typing import Dict, List, Optional, Tuple, Iterable
+from typing import Dict, Iterable, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -101,7 +101,7 @@ class GeminiManager:
         start = time()
         self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
         cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
-        # don't evict chunks that are asynchronously fetched 
+        # don't evict chunks that are asynchronously fetched
         can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
         self._layout_time += time() - start
 
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index aad97321c..e5f61a033 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -13,11 +13,17 @@ from colossalai.zero.gemini.chunk import Chunk
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector
 
+
 class PlacementPolicy(ABC):
     need_mem_stats: bool = False
 
     def __init__(
-        self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
         self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
@@ -38,13 +44,16 @@ class PlacementPolicy(ABC):
     def get_prefetch_chunks(self) -> List[Chunk]:
         raise NotImplementedError
 
+
 import os
+
 rank = int(os.environ["RANK"])
 
+
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -53,7 +62,9 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -124,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy):
 
     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -132,7 +143,9 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
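Note on the hook change: the reworked pre_op first waits on chunks that were already scheduled asynchronously in a previous step, evicts chunks while sparing any in-flight async fetches, fetches whatever is still missing synchronously, and then asks the placement policy for up to max_prefetch upcoming chunks to fetch in the background. The snippet below is a minimal, self-contained sketch of that control flow only; Chunk and Prefetcher are hypothetical toy stand-ins rather than the ColossalAI API, and a thread pool replaces the real CUDA copies/all-gathers:

# Toy sketch of the wait -> fetch-sync -> prefetch-async pattern (not ColossalAI code).
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List


class Chunk:
    def __init__(self, name: str) -> None:
        self.name = name
        self.on_cuda = False

    def fetch(self) -> None:
        # stands in for a host-to-device copy / all-gather of the chunk
        self.on_cuda = True


class Prefetcher:
    def __init__(self, max_prefetch: int) -> None:
        self.max_prefetch = max_prefetch
        self._pool = ThreadPoolExecutor(max_workers=max(1, max_prefetch))
        self._async_works: Dict[Chunk, Future] = {}

    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
        # Wait for chunks already being prefetched; return those still to fetch.
        remaining = []
        for chunk in chunks:
            work = self._async_works.pop(chunk, None)
            if work is not None:
                work.result()  # async prefetch finished, chunk is ready
            elif not chunk.on_cuda:
                remaining.append(chunk)  # must be fetched synchronously
        return remaining

    def prefetch(self, chunks: List[Chunk]) -> None:
        # Kick off up to max_prefetch asynchronous fetches for upcoming chunks.
        for chunk in chunks[: self.max_prefetch]:
            if chunk not in self._async_works and not chunk.on_cuda:
                self._async_works[chunk] = self._pool.submit(chunk.fetch)


# Usage mirroring pre_op: wait, fetch the rest synchronously, prefetch the next chunks.
prefetcher = Prefetcher(max_prefetch=2)
current, upcoming = [Chunk("w0"), Chunk("w1")], [Chunk("w2"), Chunk("w3")]
for chunk in prefetcher.wait_chunks(current):
    chunk.fetch()  # synchronous fallback for chunks that were not prefetched
prefetcher.prefetch(upcoming)  # overlap the next fetches with compute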