From 6e38eafebec514cd670758a0fb06b37bd01d224e Mon Sep 17 00:00:00 2001
From: hxwang
Date: Wed, 15 May 2024 16:51:44 +0800
Subject: [PATCH] [gemini] prefetch chunks

---
 colossalai/zero/gemini/chunk/chunk.py   | 12 ++--
 colossalai/zero/gemini/chunk/manager.py | 10 +--
 colossalai/zero/gemini/gemini_ddp.py    |  4 +-
 colossalai/zero/gemini/gemini_hook.py   | 87 +++++++++++++++++++++++--
 4 files changed, 96 insertions(+), 17 deletions(-)

diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index cad2622f2..299ea0518 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -357,14 +357,14 @@ class Chunk:
         else:
             raise NotImplementedError
 
-    def access_chunk(self):
+    def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
-
         if not self.is_gathered:
-            self.__gather()
+            return self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
+        return None
 
     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
@@ -498,17 +498,19 @@
     def get_tensors(self) -> List[torch.Tensor]:
         return list(self.tensors_info.keys())
 
-    def __gather(self):
+    def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
         if not self.is_gathered:
             # sanity check
             assert self.cuda_shard is not None
 
             alloc_storage(self.cuda_global_chunk)
             gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
+            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
 
             self.cuda_shard = None
             self.is_gathered = True
+            return work
+        return None
 
     def __scatter(self):
         if self.keep_gathered:
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 333a3f224..9cee5223e 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -111,15 +111,16 @@
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])
 
-    def access_chunk(self, chunk: Chunk) -> None:
+    def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
             return
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
-        self.__add_accessed_chunk(chunk)
+        maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
+        return maybe_work
 
     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@
         for k, v in usage.items():
             self.total_mem[k] += v
 
-    def __add_accessed_chunk(self, chunk: Chunk):
-        chunk.access_chunk()
+    def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
+        maybe_work = chunk.access_chunk(async_access=async_access)
         self.accessed_chunks.add(chunk)
         self.accessed_mem += chunk.chunk_mem
+        return maybe_work
 
     def __sub_accessed_chunk(self, chunk: Chunk):
         chunk.release_chunk()
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index c1029097a..21448bdae 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -132,7 +133,6 @@ class GeminiDDP(ModelWrapper):
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
         )
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -157,6 +157,8 @@ class GeminiDDP(ModelWrapper):
             for p in module.parameters():
                 param_order.append(p)
 
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
+
         for name, param in module.named_parameters():
             self.param2name[param] = name
             for m_name, m_var in module.named_modules():
diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py
index 480a14511..7f75f2471 100644
--- a/colossalai/zero/gemini/gemini_hook.py
+++ b/colossalai/zero/gemini/gemini_hook.py
@@ -1,14 +1,18 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import List
+from typing import Dict, List
 
 import torch
+import torch.distributed as dist
 
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
+from colossalai.zero.gemini.chunk import Chunk
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
 
 
 class TrainingPhase(Enum):
@@ -16,23 +20,92 @@
     BACKWARD = 1
 
 
+DEBUG = True  # TODO @botbw: remove
+
+
 class GeminiZeROHook(ColoParamOpHook):
-    def __init__(self, gemini_manager: GeminiManager) -> None:
+    def __init__(
+        self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
+    ) -> None:
         super().__init__()
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
+        self._cur_param = None
+        # param_visited_order might be updated somewhere else
+        self._param_visited_order = param_order.param_visited_order
+        self._max_prefetch = max_prefetch
+        self._async_works: Dict[Chunk, dist.Work] = {}
 
-    def pre_op(self, params):
-        params = [p for p in params if not is_ddp_ignored(p)]
-        chunks = self._chunk_manager.get_chunks(params)
+        # used by get_prefetch_chunks to track current param
+        self._cur_param_idx = 0
+
+    def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]:
+        chunks_to_prefetch = set()
+        if self._training_phase == TrainingPhase.FORWARD:  # forward phase: increase
+            self._cur_param_idx += len(all_params)  # need to update first
+            idx = self._cur_param_idx + 1
+            # still have params and prefetched chunks don't exceed the limit
+            while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch:
+                param = self._param_visited_order[idx]
+                if is_ddp_ignored(param):
+                    idx += 1
+                    continue
+                chunk = self._chunk_manager.get_chunk(param)
+                chunks_to_prefetch.add(chunk)
+                idx += 1
+        else:
+            assert self._training_phase == TrainingPhase.BACKWARD
+            self._cur_param_idx -= len(all_params)
+            idx = self._cur_param_idx - 1
+            chunks_to_prefetch = set()
+            while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch:
+                param = self._param_visited_order[idx]
+                if is_ddp_ignored(param):
+                    idx -= 1
+                    continue
+                chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
+                chunks_to_prefetch.add(chunk)
+                idx -= 1
+        return list(chunks_to_prefetch)
+
+    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
+        non_prefetched_chunks = []
+        for chunk in chunks:
+            if chunk in self._async_works:
+                self._async_works[chunk].wait()
+                del self._async_works[chunk]
+            else:
+                non_prefetched_chunks.append(chunk)
+        return non_prefetched_chunks
+
+    def pre_op(self, all_params):
+        if DEBUG:  # TODO @botbw: remove
+            idxs = list(map(lambda x: self._param_visited_order.index(x), all_params))
+            mx = max(idxs)
+            idxs = sorted(map(lambda x: x - mx, idxs))
+            assert list(range(len(idxs))) == idxs, f"{idxs=}"
+
+        # deal with current needed chunks
+        params = [p for p in all_params if not is_ddp_ignored(p)]
+        all_chunks = self._chunk_manager.get_chunks(params)
+        chunks_wo_work = self.wait_chunks(all_chunks)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks)
-        for chunk in chunks:
+        self._gemini_manager.adjust_layout(chunks_wo_work)
+
+        # deal with chunks that are to be async fetched
+        prefetch_chunks = self.get_prefetch_chunks(all_params)
+
+        # deal with chunks that are to be fetched now
+        for chunk in chunks_wo_work:
             self._chunk_manager.access_chunk(chunk)
 
+        # deal with chunks that are to be prefetched  TODO @botbw: the order here matters?
+        for chunk in prefetch_chunks:
+            self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True)
+
         # record cuda model data of the current OP
         self._gemini_manager.record_model_data_volume()
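
A minimal usage sketch (not part of the patch, and assuming a distributed environment has already been launched, e.g. via colossalai.launch_from_torch): the only public knob this change adds is the max_prefetch argument of GeminiDDP, which is forwarded to GeminiZeROHook and bounds how many upcoming chunks may be all-gathered asynchronously while the current op runs. The toy model below is a placeholder.

import torch.nn as nn

from colossalai.zero import GeminiDDP

# Placeholder two-layer model; any torch.nn.Module wrapped by GeminiDDP behaves the same way.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# max_prefetch=0 (the default) keeps the old synchronous behaviour; a positive value
# lets GeminiZeROHook.pre_op launch async all-gathers for upcoming chunks and wait on
# them later in wait_chunks() when those chunks are actually needed for compute.
model = GeminiDDP(model, max_prefetch=2)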