Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[zero] reorganize zero/gemini folder structure (#3424)
* [zero] refactor low-level zero folder structure
* [zero] fix legacy zero import path
* [zero] fix legacy zero import path
* [zero] remove useless import
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor gemini folder structure
* [zero] refactor legacy zero import path
* [zero] fix test import path
* [zero] fix test
* [zero] fix circular import
* [zero] update import
colossalai/zero/gemini/gemini_mgr.py (new file, 164 lines)
@@ -0,0 +1,164 @@
import functools
from time import time
from typing import List, Optional, Tuple

import torch

from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector, MemStats
from .placement_policy import PlacementPolicyFactory


class GeminiManager:
    """
    Stateful tensor manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    Args:
        placement_policy (str): Which device to place *held* tensors on. It can be 'cpu', 'cuda' or 'auto'.
            If 'cpu', parameters, gradients and optimizer states are offloaded to CPU, so minimal CUDA memory is used.
            If 'cuda', they are never offloaded, so maximal CUDA memory is used.
            If 'auto', they are moved dynamically based on CPU and CUDA memory usage, making even use of the heterogeneous memory space.
            Note that the 'auto' policy only works well when no other process uses CUDA during your training.
        chunk_manager (ChunkManager): A ``ChunkManager`` instance.
        memstats (MemStats, optional): Memory statistics collected by a runtime memory tracer. If None, GeminiManager collects them during a warmup iteration.
    """

    def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
        assert placement_policy in PlacementPolicyFactory.get_policy_names()
        self.policy_name = placement_policy
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        self._chunk_manager = chunk_manager

        self._premade_memstats_ = memstats is not None
        self._memstats = memstats
        self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
                                                           self._memstats) if policy_cls.need_mem_stats else None
        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
        self._compute_list: List[Tuple[Chunk, ...]] = []
        self._compute_idx: int = -1

        self._h2d_volume = 0
        self._d2h_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True
        self._comp_cuda_demand_time = 0
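
    # Construction sketch (illustrative, not part of this file; the ChunkManager
    # setup is elided and would normally be derived from the model's parameters):
    #
    #   chunk_manager = ChunkManager(...)
    #   manager = GeminiManager('auto', chunk_manager)
    #   if manager.need_warmup:
    #       ...  # run one warmup iteration so memory stats can be collected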

    def reset_attributes(self):
        self._compute_idx = -1
        self._h2d_volume = 0
        self._d2h_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._comp_cuda_demand_time = 0

    @property
    def need_warmup(self) -> bool:
        return self.policy_name in ('auto', 'const')

    def is_warmup(self):
        return self._warmup

    def memstats(self):
        """Get the memory statistics collected during training.

        The stats are either collected by a runtime memory tracer or by the
        GeminiManager itself. Note that in the latter case, the memstats cannot
        be accessed before the warmup iteration finishes.
        """
        if self._premade_memstats_:
            return self._memstats
        else:
            assert not self._warmup, "GeminiManager only has memstats after warmup; it is still warming up."
            return self._mem_stats_collector._memstats
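
    # Usage sketch (illustrative): without a premade MemStats, memstats() is only
    # valid once warmup is over, i.e. after the first post_iter() call:
    #
    #   stats = manager.memstats()    # AssertionError while still warming up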

    def pre_iter(self, *args):
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.start_collection()

    def post_iter(self):
        """This function must be called when each iteration finishes."""
        if self._mem_stats_collector and self._warmup:
            self._mem_stats_collector.finish_collection()
        self._warmup = False
        self.reset_attributes()

    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
        """Adjust the layout of stateful tensors according to the information
        provided by mem_stats_collector, which should belong to a sharded model.
        """
        # find stateful tensors in state COMPUTE
        start = time()
        self._record_chunks_order(chunks)
        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
        self._layout_time += time() - start

        vol, evict_time = self._placement_policy.evict_tensors(can_evict_chunks=hold_cuda_tensor_list,
                                                               cuda_demand=cuda_demand,
                                                               warmup=self._warmup,
                                                               compute_list=self._compute_list,
                                                               compute_idx=self._compute_idx)

        self._d2h_volume += vol
        self._evict_time += evict_time
        # move COMPUTE tensors to CUDA
        self._h2d_volume += cuda_demand
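
    # Call-pattern sketch (illustrative): adjust_layout() is called with the chunks
    # needed by the upcoming compute, before their payload is touched:
    #
    #   manager.adjust_layout((chunk,))    # evicts/moves chunks so the demand fits on CUDA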

    @functools.lru_cache(maxsize=None)
    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
        start = time()
        cuda_demand = 0
        for chunk in chunks:
            if chunk.device_type == 'cuda':
                if not chunk.is_gathered:
                    cuda_demand += chunk.chunk_mem - chunk.shard_mem
            elif chunk.device_type == 'cpu':
                cuda_demand += chunk.chunk_mem
            else:
                raise RuntimeError
        self._comp_cuda_demand_time += time() - start

        can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
        return cuda_demand, can_evict_chunks

    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
        self._compute_idx += 1
        if self._warmup and self._placement_policy.need_mem_stats:
            self._compute_list.append(chunks)

    @property
    def default_device(self):
        return self._placement_policy.get_default_device()

    def sample_overall_data(self):
        if self._mem_stats_collector:
            self._mem_stats_collector.sample_overall_data()

    def record_model_data_volume(self):
        if self._mem_stats_collector:
            self._mem_stats_collector.record_model_data_volume()

    @property
    def chunk_manager(self):
        return self._chunk_manager

    @property
    def cuda_margin_mem(self) -> Optional[float]:
        if self._mem_stats_collector:
            return self._mem_stats_collector.cuda_margin_mem
        return None

    @property
    def is_cuda_margin_mem_avail(self) -> bool:
        return self._placement_policy.need_mem_stats

    @staticmethod
    def get_default_device(policy_name: str) -> torch.device:
        return PlacementPolicyFactory.get_default_device(policy_name)
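
Taken together, pre_iter(), adjust_layout() and post_iter() define a per-iteration protocol: start stats collection, adjust chunk placement before each compute step, then finish collection and leave warmup. The sketch below shows how a caller might drive it; num_iters, num_steps and chunks_for_step() are hypothetical stand-ins for the training loop and are not defined in this commit.

    # Illustrative driver loop (assumes a GeminiManager `manager` and a helper
    # `chunks_for_step(step)` returning the Tuple[Chunk, ...] each step needs).
    for it in range(num_iters):
        manager.pre_iter()
        for step in range(num_steps):
            chunks = chunks_for_step(step)   # chunks the upcoming compute touches
            manager.adjust_layout(chunks)    # evict/move so the demand fits on CUDA
            ...                              # run the actual compute on those chunks
        manager.post_iter()                  # the first call ends the warmup phase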