diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 9cee5223e..c7bdd5e1f 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -114,7 +114,7 @@ class ChunkManager:
     def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
-            return
+            return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 21448bdae..b75f69a3b 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -133,6 +133,7 @@ class GeminiDDP(ModelWrapper):
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
         )
         self.force_outputs_fp32 = force_outputs_fp32
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -157,8 +158,6 @@ class GeminiDDP(ModelWrapper):
         for p in module.parameters():
             param_order.append(p)
 
-        self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
-
         for name, param in module.named_parameters():
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 01d9c9d07..82d890975 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -6,16 +6,15 @@ from typing import Dict, List
 import torch
 import torch.distributed as dist
 
-from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
-from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
-from colossalai.logging import DistributedLogger
 
 from .chunk import Chunk
 
+
 class TrainingPhase(Enum):
     FORWARD = 0
     BACKWARD = 1
@@ -23,51 +22,15 @@ class TrainingPhase(Enum):
 
 logger = DistributedLogger("gemini_hook")
 
+
 class GeminiZeROHook(ColoParamOpHook):
-    def __init__(
-        self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
-    ) -> None:
+    def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
         super().__init__()
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
-        # param_visited_order might be updated somewhere else
-        self._param_visited_order = param_order.param_visited_order
         self._max_prefetch = max_prefetch
         self._async_works: Dict[Chunk, dist.work] = {}
-        # used by get_prefetch_chunks to track current param
-        self._cur_param_idx = 0
-
-    def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]:
-        chunks_to_prefetch = set()
-        if self._training_phase == TrainingPhase.FORWARD:  # forward phrase: increase
-            self._cur_param_idx += len(all_params)  # need to update first
-            idx = self._cur_param_idx + 1
-            # still have params and prefetched chunks don't exceed the limit
-            while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch:
-                param = self._param_visited_order[idx]
-                if is_ddp_ignored(param):
-                    idx += 1
-                    continue
-                chunk = self._chunk_manager.get_chunk(param)
-                if chunk not in cur_chunks:
-                    chunks_to_prefetch.add(chunk)
-                idx += 1
-        else:
-            self._cur_param_idx -= len(all_params)
-            idx = self._cur_param_idx - 1
-            chunks_to_prefetch = set()
-            while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch:
-                param = self._param_visited_order[idx]
-                if is_ddp_ignored(param):
-                    idx -= 1
-                    continue
-                chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
-                if chunk not in cur_chunks:
-                    chunks_to_prefetch.add(chunk)
-                idx -= 1
-        print(f"cur id {self._cur_param_idx}")
-        return list(chunks_to_prefetch)
 
     def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
         non_prefetched_chunks = []
@@ -80,45 +43,25 @@ class GeminiZeROHook(ColoParamOpHook):
                 non_prefetched_chunks.append(chunk)
         return non_prefetched_chunks
 
-    def pre_op(self, all_params):
-        # def find_idx(param):
-        #     for i, p in enumerate(self._param_visited_order):
-        #         if param is p:
-        #             return i
-        #     assert False
-
-        # idxs = [find_idx(p) for p in all_params]
-        # max_id = min(idxs)
-        # idxs = [i - max_id for i in idxs]
-        # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}'
-
-        # deal with current needed chunks
-        params = [p for p in all_params if not is_ddp_ignored(p)]
+    def pre_op(self, params):
+        params = [p for p in params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)
-        chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks))
+        # wait for prefetched chunks, filter those are not prefetched
+        chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync)
-
-        # deal with chunks that are to be async fetched
-        chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync)
-
-        print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}")
-        # deal with chunks that are to be fetched now
-        for chunk in chunks_need_to_fetch_sync:
+        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
+        # fetch the rest chunks synchronously
+        for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
-
-        # deal with chunks that are to be pre fetched TODO @botbw: the order here matters?
-        for chunk in chunks_can_be_fetch_async:
-            if chunk in self._async_works:
-                continue
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
+        for chunk in chunks_fetch_async:
             maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
             if maybe_work is not None:
-                print(f"prefetch {chunk.count_id}")
                 self._async_works[chunk] = maybe_work
 
-        # record cuda model data of the current OP
+        # record cuda model data of the current OP, including memory for prefetched chunks
        self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index 150932e3d..0362d6523 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -6,7 +6,7 @@ import torch
 
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
-from .placement_policy import PlacementPolicyFactory
+from .placement_policy import PlacementPolicy, PlacementPolicyFactory
 
 
 class GeminiManager:
@@ -91,13 +91,13 @@ class GeminiManager:
         self._warmup = False
         self.reset_attributes()
 
-    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
         start = time()
-        self._record_chunks_order(chunks)
+        self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
         cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
         self._layout_time += time() - start
 
@@ -133,9 +133,9 @@ class GeminiManager:
         can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks
 
-    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
+    def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
-        if self._warmup and self._placement_policy.need_mem_stats:
+        if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)
 
     def sample_overall_data(self):
@@ -156,6 +156,18 @@ class GeminiManager:
             return self._mem_stats_collector.cuda_margin_mem
         return None
 
+    @property
+    def compute_list(self) -> List[Tuple[Chunk, ...]]:
+        return self._compute_list
+
+    @property
+    def compute_idx(self) -> int:
+        return self._compute_idx
+
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 388999549..452687b7d 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
         self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, torch.device]
     ) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        raise NotImplementedError
+
 
 class StaticPlacementPolicy(PlacementPolicy):
@@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
 
+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        prefetch = []
+        for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
+            for chunk in self.chunk_manager.compute_list[i]:
+                if len(prefetch) >= max_prefetch:
+                    break
+                if chunk not in prefetch:
+                    prefetch.append(chunk)
+            if len(prefetch) >= max_prefetch:
+                break
+        return prefetch
+
 
 class AutoPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = True
@@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")
 
+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        return []  # TODO @botbw: implement prefetching for auto
+
 
 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {
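Usage note (outside the patch): a minimal sketch of how the prefetch path introduced above would be enabled, assuming GeminiDDP forwards a max_prefetch argument to GeminiZeROHook as shown in the gemini_ddp.py hunk; the toy model and the surrounding launch/training code below are hypothetical placeholders.

    # hypothetical example, not part of the patch
    import torch.nn as nn

    from colossalai.zero import GeminiDDP

    # assumes the distributed environment has already been initialized
    # (e.g. via colossalai.launch_from_torch), as usual for Gemini
    model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

    # max_prefetch > 0 lets GeminiZeROHook.pre_op ask the placement policy for
    # chunks needed by upcoming ops and fetch them asynchronously, overlapping
    # communication with compute; 0 keeps the previous synchronous behavior
    model = GeminiDDP(model, placement_policy="static", max_prefetch=2)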