diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index 693899c01..8ad5b8ba2 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -120,7 +120,7 @@ class ColoDDPV2(ColoDDP): def _setup_grads_ptr(self): for p in self.module.parameters(): - if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad: + if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad: p.grad = None else: p.grad = p.data @@ -154,7 +154,7 @@ class ColoDDPV2(ColoDDP): chunk = self.chunk_manager.get_chunk(p) reduced = self.chunk_manager.reduce_chunk(chunk) self.chunk_manager.release_chunk(chunk) - if reduced and not chunk.is_free: + if reduced and not chunk.is_empty: self.overflow_counter += chunk.has_inf_or_nan self.chunk_manager.move_chunk(chunk, self.grads_device[p]) return empty_grad diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py index 94da00157..24243bf42 100644 --- a/colossalai/tensor/chunk.py +++ b/colossalai/tensor/chunk.py @@ -38,6 +38,16 @@ class ChunkFullError(Exception): class Chunk: + """ + A chunk is a contiguous memory space which contains multiple tensors. + + Args: + chunk_size (int): the number of elements in a chunk + src_rank (int): the process which owns the chunk + dtype (torch.dtype): the data type of the chunk + init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU. + """ + def __init__(self, chunk_size: int, src_rank: int, @@ -51,17 +61,34 @@ class Chunk: self.dtype = dtype self.device = init_device or get_current_device() self.data = torch.empty(chunk_size, dtype=dtype, device=self.device) + + # we only keep the chunk in full in the process by which the tensor is owned if not self.is_src_rank: self.data.storage().resize_(0) + + # each tensor is associated with a TensorInfo to track meta info self.tensors_info: Dict[torch.Tensor, TensorInfo] = {} self.mem = self.size * self.data.element_size() def append(self, tensor: torch.Tensor) -> None: + """ + Add a tensor to the chunk. + + Args: + tensor (torch.Tensor): a tensor to be added to the chunk + """ assert tensor.dtype == self.dtype new_utilized_size = self.utilized_size + tensor.numel() + + # raise exception when the chunk size is exceeded if new_utilized_size > self.size: raise ChunkFullError + + # set tensor state tensor_state = TensorState.FREE + + # if the process owns the rank, then copy the tensor to its chunk buffer + # otherwise set its storage size to 0 to reduce memory consumption if self.is_src_rank: self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1)) tensor_state = TensorState.HOLD @@ -72,6 +99,9 @@ class Chunk: self.utilized_size = new_utilized_size def release(self) -> None: + """ + Release the memory space on processes which do not own the chunk. + """ if not self.is_src_rank: self.data.storage().resize_(0) self._update_tensors_state(TensorState.FREE) @@ -86,19 +116,38 @@ class Chunk: tensor_info.state = next_state def access(self) -> None: + """ + Broadcast the chunk to synchronize the tensors across data parallel processes. + """ + # recover the chunk on non-owner processes + # and broadcast the chunk from the source to all processes if not self.is_src_rank: self.data.storage().resize_(self.size) self.data.data = self.data.to(get_current_device()) dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) + + # update tensor meta info self._update_tensors_ptr() if not self.is_src_rank: self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE) def move_device(self, device: torch.device) -> None: + """ + Move the chunk to a target device. + + Args: + device (torch.device): the target device for data movement. + """ self.data.data = self.data.to(device) self._update_tensors_ptr() def reduce(self, is_all_reduce: bool = False) -> None: + """ + Reduce or all-reduce the chunk. + + Args: + is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false. + """ self.data.data = self.data.to(get_current_device()) if is_all_reduce: dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA)) @@ -108,6 +157,13 @@ class Chunk: self._update_tensors_state(TensorState.HOLD) def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None: + """ + Make a transition of the tensor into the next state. + + Args: + tensor (torch.Tensor): a torch Tensor object. + tensor_state (TensorState): the target state for transition. + """ assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE' # As the gradient hook can be triggered either before or after post-backward # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce @@ -123,12 +179,22 @@ class Chunk: self.tensors_info[tensor].state = tensor_state def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Copy data slice to the memory space indexed by the input tensor in the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data_slice (torch.Tensor): the tensor to be copied to the chunk + """ tensor_info = self.tensors_info[tensor] self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1)) tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape) @property def can_release(self) -> bool: + """ + Check whether the chunk can be released. + """ for tensor_info in self.tensors_info.values(): if tensor_info.state != TensorState.HOLD: return False @@ -136,6 +202,9 @@ class Chunk: @property def can_move_device(self) -> bool: + """ + Check whether the chunk can be moved across devices. + """ for tensor_info in self.tensors_info.values(): if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE): return False @@ -143,26 +212,38 @@ class Chunk: @property def can_reduce(self) -> bool: + """ + Check whether the chunk can be reduced. + """ for tensor_info in self.tensors_info.values(): if tensor_info.state != TensorState.READY_FOR_REDUCE: return False return True @property - def is_free(self) -> bool: + def is_empty(self) -> bool: + """ + Check whether the chunk is empty. + """ return self.data.storage().size() == 0 def __repr__(self) -> str: - return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_free}, tensor states={[info.state.name for info in self.tensors_info.values()]}' + return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}' @property def has_inf_or_nan(self) -> bool: + """ + Check if the chunk has inf or nan values. + """ return torch.isinf(self.data[:self.utilized_size]).any().item() or \ torch.isnan(self.data[:self.utilized_size]).any().item() def copy_(self, dest_chunk: 'Chunk'): - assert not self.is_free - assert not dest_chunk.is_free + """ + Copy the data of this chunk to a destination chunk. + """ + assert not self.is_empty + assert not dest_chunk.is_empty assert self.size == dest_chunk.size assert self.utilized_size == dest_chunk.utilized_size self.data.copy_(dest_chunk.data) @@ -170,6 +251,9 @@ class Chunk: @property def device_type(self) -> str: + """ + Get the device type of the chunk. + """ return self.data.device.type def __hash__(self) -> int: @@ -183,6 +267,14 @@ class Chunk: class ChunkManager: + """ + A manager class to manipulate the tensors in chunks. + + Args: + chunk_size (int): the size of a chunk. + enable_distributed_storage (bool): optional, allow for distributed storage of a chunk. The default is false. + init_device (torch.device): optional, the device on which the chunk is initialized. The default is None. + """ def __init__(self, chunk_size: Optional[int], @@ -201,54 +293,89 @@ class ChunkManager: self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0} def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None: + """ + Append a tensor to a chunk. + + Args: + tensor (torch.Tensor): a tensor to append to the chunk. + group_name (str): the name of the chunk group. + """ assert tensor not in self.tensor_chunk_map if self.chunk_size is not None and tensor.numel() > self.chunk_size: raise ValueError( f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})') if group_name not in self.chunk_groups: self.chunk_groups[group_name] = deque() + try: + # append the tensor to the last chunk self.chunk_groups[group_name][-1].append(tensor) except (IndexError, ChunkFullError): + # the except statement will be triggered when there is no chunk or + # the last chunk in the chunk group is full + # this will create a new chunk and allocate this chunk to its corresponding process chunk_size = self.chunk_size or tensor.numel() src_rank = self._get_next_src_rank(group_name) chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device) + if self.enable_distributed_storage and self.chunk_size is None: self.rank_load[group_name][src_rank] += chunk_size + self.chunk_groups[group_name].append(chunk) chunk.append(tensor) - if not chunk.is_free: + if not chunk.is_empty: self.total_mem[chunk.device_type] += chunk.mem self.tensor_chunk_map[tensor] = self.chunk_groups[group_name][-1] if not self.enable_distributed_storage: + # as distributed storage is not enabled, there is no need to broadcast + # chunks, thus we set these chunks as accessed self.accessed_chunks.add(self.chunk_groups[group_name][-1]) def _get_next_src_rank(self, group_name: str) -> int: if not self.enable_distributed_storage: + # the chunk is owned by the current rank if no distributed storage is enabled return gpc.get_local_rank(ParallelMode.DATA) if self.chunk_size is None: if group_name not in self.rank_load: self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64) + + # the process owning the tensor will be the process with the smallest number of elements src_rank = torch.argmin(self.rank_load[group_name]).item() else: + # chunk is owned by processes in a round-robin fashion chunk_idx = len(self.chunk_groups[group_name]) src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA) return src_rank def access_chunk(self, chunk: Chunk) -> None: + """ + Synchronize the chunks via broadcast. + + Args: + chunk (Chunk): the chunk to synchronize. + """ if chunk in self.accessed_chunks: if chunk.device_type != 'cuda': self.total_mem[chunk.device_type] -= chunk.mem chunk.move_device(get_current_device()) self.total_mem[chunk.device_type] += chunk.mem return - if not chunk.is_free: + if not chunk.is_empty: + # as tensor is moved to the target device + # the memory consumption of the original device is reduced self.total_mem[chunk.device_type] -= chunk.mem chunk.access() self.accessed_chunks.add(chunk) self.total_mem[chunk.device_type] += chunk.mem def release_chunk(self, chunk: Chunk) -> None: + """ + Release the memory space of a chunk. + + Args: + chunk (Chunk): the chunk to release memory space + """ + if not self.enable_distributed_storage: return if chunk not in self.accessed_chunks: @@ -256,22 +383,44 @@ class ChunkManager: if chunk.can_release: chunk.release() self.accessed_chunks.remove(chunk) - if chunk.is_free: + if chunk.is_empty: + # update the memory consumption after releasing self.total_mem[chunk.device_type] -= chunk.mem def move_chunk(self, chunk: Chunk, device: torch.device) -> None: + """ + Move the chunk to the target device. + + Args: + chunk (Chunk): the chunk to move to target device + device (torch.device): target device + """ if chunk.data.device == device: return - if chunk.can_move_device and not chunk.is_free: + if chunk.can_move_device and not chunk.is_empty: self.total_mem[chunk.device_type] -= chunk.mem chunk.move_device(device) self.total_mem[chunk.device_type] += chunk.mem def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None: + """ + Transit tensor state according to pre-defined state machine. + + Args: + tensor (torch.Tensor): the tensor for state transititon + state (TensorState): next tensor state for transtition + """ chunk = self.tensor_chunk_map[tensor] chunk.tensor_trans_state(tensor, state) def reduce_chunk(self, chunk: Chunk) -> bool: + """ + Reduce or all reduce the chunk. If enable_distributed_storage is true, all-reduce is used. + Otherwise, this method uses reduce. + + Args: + chunk (Chunk): the chunk for reduction. + """ if not chunk.can_reduce: return False self.total_mem[chunk.device_type] -= chunk.mem @@ -280,16 +429,39 @@ class ChunkManager: return True def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None: + """ + Copy data to the chunk. + + Args: + tensor (torch.Tensor): the tensor used to retrive meta information + data (torch.Tensor): the tensor to be copied to the chunk + """ chunk = self.tensor_chunk_map[tensor] chunk.copy_tensor_to_chunk_slice(tensor, data) def get_chunk(self, tensor: torch.Tensor) -> Chunk: + """ + Return the chunk owning the tensor. + + Args: + tensor (torch.Tensor): a torch tensor object + """ return self.tensor_chunk_map[tensor] def add_lazy_release_tensors(self, tensors: List[torch.Tensor]) -> None: + """ + Add tensors to the buffer for lazy release. + + Args: + tensors (List[torch.Tensor]): the tensors to be released lazily + """ self.lazy_release_tensors.extend(tensors) def exec_lazy_release(self) -> None: + """ + Execute release for tensors added to the lazy release buffer. + """ + for chunk in self.get_chunks(self.lazy_release_tensors): self.release_chunk(chunk) self.lazy_release_tensors.clear() @@ -305,6 +477,13 @@ class ChunkManager: @staticmethod def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float: + """ + Calculate the utilization rate of a chunk. + + Args: + chunk_size (int): the size of a chunk + params_numel (List[int]): the list of integers representing the number of elements of parameters + """ assert len(params_numel) > 0 total_size = 0 total_utilized_size = 0 @@ -323,6 +502,17 @@ class ChunkManager: search_range: int, n_grids: int, min_chunk_size: Optional[int] = None) -> int: + """ + Search for the chunk size for optimal chunk utilization. + + Args: + module (torch.nn.Module): a torch module object + search_range (int): the range of chunk size to search. The actual search range will be from + max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + search_range. + n_grids (int): the number of intervals in the search range + min_chunk_size (int): optional, the minimum size for a chunk. The default is None. + + """ assert search_range % n_grids == 0 # TODO(ver217): sort params and filter unused ones params_numel = [p.numel() for p in module.parameters()] @@ -342,11 +532,24 @@ class ChunkManager: return best_chunk_size def copy_chunk_group(self, dest_group_name: str, src_group_name: str): + """ + Copy chunk data from one group to another group. + + Args: + dest_group_name (str): the destination group which receives the copied data + src_group_name (str): the source group which provides the data to copy + """ for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]): - if not dest_chunk.is_free: + if not dest_chunk.is_empty: dest_chunk.copy_(src_chunk) def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]: + """ + Get all chunks owning the input tensors. + + Args: + tensors (Iterable[torch.Tensor]): the tensors used to look for chunks + """ chunks = [] for tensor in tensors: chunk = self.get_chunk(tensor) diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py index 36209544d..b9208bace 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/zero/zero_optimizer.py @@ -64,7 +64,7 @@ class ZeroOptimizer(ColossalaiOptimizer): def _update_params_ptr(self): for group in self.optim.param_groups: for p in group['params']: - if not self.module.chunk_manager.get_chunk(p).is_free: + if not self.module.chunk_manager.get_chunk(p).is_empty: p.data = self.fp16_param_to_fp32_param[p] else: assert p.grad is None