[chore] refactor & sync

hxwang
2024-05-16 07:22:10 +00:00
parent 4148ceed9f
commit 2e68eebdfe
7 changed files with 82 additions and 46 deletions


@@ -13,15 +13,16 @@ from colossalai.zero.gemini.chunk import Chunk
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector

 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

     def __init__(
-        self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
+        self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch: int = 0, **kwargs
     ) -> None:
+        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
+        self.max_prefetch = max_prefetch

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
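For orientation, a minimal sketch of a subclass written against the new base-class interface: the constructor now receives gemini_manager and max_prefetch, and get_prefetch_chunks() takes no argument. NoopPolicy and its trivial behavior are hypothetical and only illustrate the contract; they are not part of this commit.

class NoopPolicy(PlacementPolicy):
    # Hypothetical subclass: shows the new constructor/prefetch contract only.
    def __init__(self, gemini_manager, chunk_manager, mem_stats_collector=None, max_prefetch=0, **kwargs):
        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)

    def evict_tensors(self, can_evict_chunks, **kwargs):
        return 0, 0.0  # evict nothing, report zero freed memory and zero time

    def get_prefetch_chunks(self):
        return []  # the prefetch budget is now read from self.max_prefetch, not passed in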
@@ -34,21 +35,25 @@ class PlacementPolicy(ABC):
         raise NotImplementedError

     @abstractmethod
-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+    def get_prefetch_chunks(self) -> List[Chunk]:
         raise NotImplementedError

+import os
+rank = int(os.environ["RANK"])

 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
+        gemini_manager: 'GeminiManager',
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,
         offload_optim_frac: float = 0.0,
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
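Assuming an existing Gemini setup that already provides gemini_manager, chunk_manager, and mem_stats_collector objects (not shown here), the updated subclass would now be constructed roughly like this; the value 4 is an arbitrary example budget:

policy = StaticPlacementPolicy(
    gemini_manager,                          # newly required first argument
    chunk_manager,
    mem_stats_collector=mem_stats_collector,
    max_prefetch=4,                          # allow up to 4 chunks to be prefetched ahead of compute
    shard_param_frac=1.0,
    offload_optim_frac=0.0,
    offload_param_frac=0.0,
)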
@@ -99,15 +104,17 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
         prefetch = []
-        for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
-            for chunk in self.chunk_manager.compute_list[i]:
-                if len(prefetch) >= max_prefetch:
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            for chunk in self.gemini_manager.compute_list[i]:
+                if len(prefetch) >= self.max_prefetch:
                     break
-                if chunk not in prefetch:
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if len(prefetch) >= max_prefetch:
+            if len(prefetch) >= self.max_prefetch:
                 break
         return prefetch
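A hedged sketch of how a caller could drive the parameterless method: prefetch_upcoming_chunks is an illustrative helper, not a call site from this commit, and the use of chunk_manager.access_chunk() to materialize a chunk is an assumption about the surrounding API.

def prefetch_upcoming_chunks(policy: PlacementPolicy, chunk_manager: ChunkManager) -> None:
    # Illustrative helper, not part of the commit.
    if policy.max_prefetch <= 0:
        return  # prefetching disabled
    for chunk in policy.get_prefetch_chunks():  # returns [] during warmup
        # Assumption: access_chunk() gathers/moves the chunk so it is resident
        # on CUDA before compute reaches it later in compute_list.
        chunk_manager.access_chunk(chunk)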
@@ -117,13 +124,15 @@ class AutoPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
+        gemini_manager: 'GeminiManager',
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         warmup_non_model_data_ratio: float = 0.8,
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
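AutoPlacementPolicy follows the same pattern; a sketch of the updated call, again assuming the manager and collector objects come from an existing Gemini setup, with the ratio values simply echoing the defaults above:

auto_policy = AutoPlacementPolicy(
    gemini_manager,
    chunk_manager,
    mem_stats_collector=mem_stats_collector,
    max_prefetch=2,                      # 0 (the default) keeps prefetching disabled
    warmup_non_model_data_ratio=0.8,
    steady_cuda_cap_ratio=0.9,
)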