[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2024-05-16 07:26:19 +00:00
parent 82b25524ff
commit 6bbe956316
4 changed files with 32 additions and 17 deletions
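These are stock formatter fixes: over-long signatures are exploded one parameter per line with a trailing comma, and single-quoted strings are normalized to double quotes. As a hedged sketch (the repository's actual hook configuration is not shown in this diff; black with a 120-character line length is assumed), the same rewrite can be reproduced with black's Python API:

# Sketch: reproduce this commit's style fix with black (pip install black).
# Assumes black is the active hook and a 120-character line length.
import black

src = (
    "def f(self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, "
    "mem_stats_collector: Optional[ChunkMemStatsCollector] = None, "
    "max_prefetch:int = 0, **kwargs) -> None: ..."
)

# black splits the over-long signature one parameter per line, appends a
# trailing comma, and rewrites single quotes as double quotes.
print(black.format_str(src, mode=black.Mode(line_length=120)))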

@@ -13,11 +13,17 @@ from colossalai.zero.gemini.chunk import Chunk
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector
 
 class PlacementPolicy(ABC):
     need_mem_stats: bool = False
 
     def __init__(
-        self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
         self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
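The quoted "GeminiManager" annotation in this signature is a forward reference: the class is not imported at module scope, a common way to break a circular import. A minimal self-contained sketch of the pattern (the gemini_mgr import path is an assumption, not shown in this diff):

# Sketch of the quoted-annotation pattern used above. The import below is
# assumed for illustration; only type checkers ever evaluate it.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported for annotations only, so no circular import at runtime
    from colossalai.zero.gemini.gemini_mgr import GeminiManager


class ExamplePolicy:
    def __init__(self, gemini_manager: "GeminiManager", max_prefetch: int = 0) -> None:
        self.gemini_manager = gemini_manager
        self.max_prefetch = max_prefetch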
@@ -38,13 +44,16 @@ class PlacementPolicy(ABC):
     def get_prefetch_chunks(self) -> List[Chunk]:
         raise NotImplementedError
 
 import os
 
 rank = int(os.environ["RANK"])
 
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
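Note the module-level import os / rank = int(os.environ["RANK"]) pair that this hunk merely reflows: it reads like leftover debug code, since it runs at import time and raises KeyError in any process where RANK is unset (e.g. a non-distributed run). If the rank is genuinely needed at module scope, a defensive sketch would be:

import os

# Default to rank 0 when the launcher did not set RANK, instead of
# failing with KeyError at import time.
rank = int(os.environ.get("RANK", "0"))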
@@ -53,7 +62,9 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
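For context on the guard above: offload_param_frac only takes effect when parameters are fully sharded (shard_param_frac == 1.0) and optimizer state is fully offloaded (offload_optim_frac == 1.0); otherwise it is reset to 0.0 with a warning. A standalone rerun of that check, with illustrative values:

# Illustrative values, not from the diff: half-sharded params, full optim offload.
shard_param_frac, offload_optim_frac, offload_param_frac = 0.5, 1.0, 0.3

# Same guard as the hunk above: the param-offload fraction is dropped
# unless params are fully sharded and optimizer state fully offloaded.
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
    offload_param_frac = 0.0

print(offload_param_frac)  # 0.0, because shard_param_frac != 1.0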
@@ -124,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -132,7 +143,9 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
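The two ratios in this hunk bound GPU memory use in different phases: during warmup, model data may occupy (1 - _warmup_non_model_data_ratio) of device memory; in steady state, total CUDA use is capped at steady_cuda_cap_ratio of capacity. A back-of-envelope sketch (the 0.8 warmup default and the 80 GB capacity are assumptions for illustration):

# Back-of-envelope memory budgets; capacity and warmup ratio are assumed.
capacity_gb = 80.0
warmup_non_model_data_ratio = 0.8  # assumed default, not visible in this hunk
steady_cuda_cap_ratio = 0.9        # matches the default shown in the diff

warmup_model_budget_gb = (1 - warmup_non_model_data_ratio) * capacity_gb
steady_cap_gb = steady_cuda_cap_ratio * capacity_gb

print(warmup_model_budget_gb)  # ~16 GB for model data during warmup
print(steady_cap_gb)           # 72.0 GB total CUDA budget in steady state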