mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
refactor the code structure to solve the circular import
This commit is contained in:
@@ -5,13 +5,13 @@ from time import time
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .gemini_mgr import GeminiManager
|
||||
from .memory_tracer import ChunkMemStatsCollector
|
||||
|
||||
|
||||
@@ -20,13 +20,11 @@ class PlacementPolicy(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gemini_manager: "GeminiManager", # TODO @botbw: solve circular import
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
max_prefetch: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = chunk_manager
|
||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||
self.max_prefetch = max_prefetch
|
||||
@@ -41,14 +39,15 @@ class PlacementPolicy(ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||
def get_prefetch_chunks(
|
||||
self, is_warmup, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
|
||||
) -> List[Chunk]:
|
||||
return [] # no prefetch by default
|
||||
|
||||
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
gemini_manager: "GeminiManager",
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
max_prefetch: int = 0,
|
||||
@@ -57,9 +56,7 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
|
||||
)
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
|
||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
|
||||
offload_param_frac = 0.0
|
||||
@@ -110,13 +107,15 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||
|
||||
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list
|
||||
def get_prefetch_chunks(
|
||||
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
|
||||
) -> List[Chunk]:
|
||||
if is_warmup: # no prefetch during warmup since we need compute_list
|
||||
return []
|
||||
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
|
||||
can_prefetch = self.max_prefetch - len(async_works)
|
||||
prefetch = []
|
||||
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
|
||||
for chunk in self.gemini_manager.compute_list[i]:
|
||||
for i in range(compute_idx + 1, len(compute_list)):
|
||||
for chunk in compute_list[i]:
|
||||
if len(prefetch) >= can_prefetch:
|
||||
break
|
||||
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
|
||||
@@ -132,7 +131,6 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gemini_manager: GeminiManager,
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
max_prefetch: int = 0,
|
||||
@@ -140,9 +138,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
|
||||
)
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
@@ -233,8 +229,10 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
else:
|
||||
grads_device_map[p] = torch.device("cpu")
|
||||
|
||||
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list
|
||||
def get_prefetch_chunks(
|
||||
self, is_warmup: bool, compute_list: tuple, compute_idx: int, async_works: Dict[Chunk, dist.Work]
|
||||
) -> List[Chunk]:
|
||||
if is_warmup: # no prefetch during warmup since we need compute_list
|
||||
return []
|
||||
# modified from self.evict_tensors
|
||||
cuda_capacity = self._steady_cuda_cap_ratio * colo_device_memory_capacity(
|
||||
@@ -246,14 +244,14 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
|
||||
prefetch_chunk_memory = 0
|
||||
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
|
||||
can_prefetch = self.max_prefetch - len(async_works)
|
||||
prefetch = []
|
||||
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
|
||||
for chunk in self.gemini_manager.compute_list[i]:
|
||||
chunk: Chunk
|
||||
for i in range(compute_idx + 1, len(compute_list)):
|
||||
for chunk in compute_list[i]:
|
||||
if len(prefetch) >= can_prefetch or prefetch_chunk_memory + chunk.chunk_mem > avail_cuda_model_data:
|
||||
break
|
||||
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
|
||||
prefetch_chunk_memory += chunk.chunk_mem
|
||||
prefetch.append(chunk)
|
||||
else:
|
||||
continue
|
||||
|
Reference in New Issue
Block a user