[gemini] prefetch chunks

Author: hxwang
Date: 2024-05-15 16:51:44 +08:00
Parent: 785cd9a9c9
Commit: 6e38eafebe
4 changed files with 96 additions and 17 deletions

Changed file 1 of 4:

@@ -357,14 +357,14 @@ class Chunk:
         else:
             raise NotImplementedError
 
-    def access_chunk(self):
+    def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None
         if not self.is_gathered:
-            self.__gather()
+            return self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
+        return None
 
     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
@@ -498,17 +498,19 @@ class Chunk:
     def get_tensors(self) -> List[torch.Tensor]:
         return list(self.tensors_info.keys())
 
-    def __gather(self):
+    def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
         if not self.is_gathered:
             # sanity check
             assert self.cuda_shard is not None
 
             alloc_storage(self.cuda_global_chunk)
             gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
+            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
 
             self.cuda_shard = None
             self.is_gathered = True
+            return work
+        return None
 
     def __scatter(self):
         if self.keep_gathered:
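
The mechanism the new async_access path relies on is torch.distributed's async_op flag: the all-gather is issued on the communication stream and a Work handle is returned immediately instead of blocking. A minimal sketch of that pattern, not part of the commit (gather_shard_async is a hypothetical helper; an initialized process group is assumed):

import torch
import torch.distributed as dist

def gather_shard_async(global_chunk: torch.Tensor, shard: torch.Tensor, pg: dist.ProcessGroup) -> dist.Work:
    # one view per rank over the flat chunk buffer
    gather_list = list(torch.chunk(global_chunk, chunks=pg.size(), dim=0))
    # async_op=True returns a Work handle immediately instead of blocking
    return dist.all_gather(gather_list, shard, group=pg, async_op=True)

# usage sketch:
# work = gather_shard_async(chunk_buf, my_shard, process_group)
# ... queue unrelated compute here, overlapping with the gather ...
# work.wait()  # the chunk is only safe to read after this returns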

Changed file 2 of 4:

@@ -111,15 +111,16 @@ class ChunkManager:
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])
 
-    def access_chunk(self, chunk: Chunk) -> None:
+    def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
             return
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
-        self.__add_accessed_chunk(chunk)
+        maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
+        return maybe_work
 
     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@ class ChunkManager:
         for k, v in usage.items():
             self.total_mem[k] += v
 
-    def __add_accessed_chunk(self, chunk: Chunk):
-        chunk.access_chunk()
+    def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
+        maybe_work = chunk.access_chunk(async_access=async_access)
         self.accessed_chunks.add(chunk)
         self.accessed_mem += chunk.chunk_mem
+        return maybe_work
 
     def __sub_accessed_chunk(self, chunk: Chunk):
         chunk.release_chunk()
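
The manager only forwards the chunk's optional Work handle; a caller that asks for async access is responsible for keeping the handle and waiting on it before the chunk is used. A hedged sketch of that contract, with placeholder names that are not part of the commit:

from typing import Dict, Optional
import torch.distributed as dist

def access_with_optional_prefetch(chunk_manager, chunk, pending: Dict) -> None:
    # With async_access=True the return value is either a dist.Work handle
    # (gather still in flight) or None (chunk already gathered / gathered synchronously).
    maybe_work: Optional[dist.Work] = chunk_manager.access_chunk(chunk, async_access=True)
    if maybe_work is not None:
        pending[chunk] = maybe_work  # remember the handle; wait on it before use

def wait_if_pending(chunk, pending: Dict) -> None:
    # Block until the chunk's outstanding gather (if any) has completed.
    work = pending.pop(chunk, None)
    if work is not None:
        work.wait()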

Changed file 3 of 4:

@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -132,7 +133,6 @@ class GeminiDDP(ModelWrapper):
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
         )
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -157,6 +157,8 @@ class GeminiDDP(ModelWrapper):
         for p in module.parameters():
             param_order.append(p)
 
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
+
         for name, param in module.named_parameters():
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
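
With the new constructor argument, prefetching is opted into by passing max_prefetch > 0 when wrapping the model; the default of 0 keeps the previous synchronous behaviour. A hedged usage sketch, assuming colossalai's distributed environment has already been launched (my_module is a placeholder and the import path reflects the layout assumed at the time of this commit):

import torch.nn as nn
from colossalai.zero import GeminiDDP

my_module = nn.Linear(1024, 1024)  # placeholder model
model = GeminiDDP(
    my_module,
    placement_policy="static",
    max_prefetch=2,  # enable prefetching of upcoming chunks (0, the default, disables it)
)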

Changed file 4 of 4:

@@ -1,14 +1,18 @@
+from chunk import Chunk
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import List
+from typing import Dict, List
 
 import torch
+import torch.distributed as dist
 
+from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
 
 
 class TrainingPhase(Enum):
@@ -16,23 +20,92 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
+DEBUG = True  # TODO @botbw: remove
+
+
 class GeminiZeROHook(ColoParamOpHook):
-    def __init__(self, gemini_manager: GeminiManager) -> None:
+    def __init__(
+        self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
+    ) -> None:
         super().__init__()
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
+        self._cur_param = None
+        # param_visited_order might be updated somewhere else
+        self._param_visited_order = param_order.param_visited_order
+        self._max_prefetch = max_prefetch
+        self._async_works: Dict[Chunk, dist.work] = {}
 
-    def pre_op(self, params):
-        params = [p for p in params if not is_ddp_ignored(p)]
-        chunks = self._chunk_manager.get_chunks(params)
+        # used by get_prefetch_chunks to track current param
+        self._cur_param_idx = 0
+
+    def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]:
+        chunks_to_prefetch = set()
+        if self._training_phase == TrainingPhase.FORWARD:  # forward phrase: increase
+            self._cur_param_idx += len(all_params)  # need to update first
+            idx = self._cur_param_idx + 1
+            # still have params and prefetched chunks don't exceed the limit
+            while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch:
+                param = self._param_visited_order[idx]
+                if is_ddp_ignored(param):
+                    idx += 1
+                    continue
+                chunk = self._chunk_manager.get_chunk(param)
+                chunks_to_prefetch.add(chunk)
+                idx += 1
+        else:
+            assert self._training_phase == TrainingPhase.BACKWARD
+            self._cur_param_idx -= len(all_params)
+            idx = self._cur_param_idx - 1
+            chunks_to_prefetch = set()
+            while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch:
+                param = self._param_visited_order[idx]
+                if is_ddp_ignored(param):
+                    idx -= 1
+                    continue
+                chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
+                chunks_to_prefetch.add(chunk)
+                idx -= 1
+        return list(chunks_to_prefetch)
+
+    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
+        non_prefetched_chunks = []
+        for chunk in chunks:
+            if chunk in self._async_works:
+                self._async_works[chunk].wait()
+                del self._async_works[chunk]
+            else:
+                non_prefetched_chunks.append(chunk)
+        return non_prefetched_chunks
+
+    def pre_op(self, all_params):
+        if DEBUG:  # TODO @botbw: remove
+            idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params))
+            mx = max(idxs)
+            idxs = sorted(map(lambda x: x - mx, idxs))
+            assert list(range(len(idxs))) == idxs, f"{idxs=}"
+
+        # deal with current needed chunks
+        params = [p for p in all_params if not is_ddp_ignored(p)]
+        all_chunks = self._chunk_manager.get_chunks(params)
+        chunks_wo_work = self.wait_chunks(all_chunks)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks)
-        for chunk in chunks:
+        self._gemini_manager.adjust_layout(chunks_wo_work)
+
+        # deal with chunks that are to be async fetched
+        prefetch_chunks = self.get_prefetch_chunks(all_params)
+
+        # deal with chunks that are to be fetched now
+        for chunk in chunks_wo_work:
             self._chunk_manager.access_chunk(chunk)
+
+        # deal with chunks that are to be pre fetched TODO @botbw: the order here matters?
+        for chunk in prefetch_chunks:
+            self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True)
 
         # record cuda model data of the current OP
         self._gemini_manager.record_model_data_volume()
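
Taken together, the hook's forward-phase lookahead amounts to: advance past the parameters of the current op in the recorded visit order, then collect the chunks of the next few parameters, stopping at the max_prefetch budget. A simplified, self-contained sketch of that window logic (visit_order and chunk_of are stand-ins, not the commit's API):

from typing import Callable, List, Sequence

def next_chunks_to_prefetch(
    visit_order: Sequence[str],      # stand-in for the recorded parameter order
    cur_param_idx: int,              # index just past the current op's parameters
    max_prefetch: int,
    chunk_of: Callable[[str], str],  # stand-in for ChunkManager.get_chunk
) -> List[str]:
    chunks, idx = set(), cur_param_idx + 1
    # mirror the hook's bound: collect at most max_prefetch - 1 distinct chunks
    while idx < len(visit_order) and len(chunks) + 1 < max_prefetch:
        chunks.add(chunk_of(visit_order[idx]))
        idx += 1
    return list(chunks)

# e.g. with max_prefetch=3, at most two distinct upcoming chunks are returned,
# and their gathers are issued asynchronously while the current op computes.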