From c577ed016eea370f4310b0b308b4ad7cbd7392e3 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Tue, 9 Aug 2022 16:39:48 +0800
Subject: [PATCH] [zero] add AgChunk (#1417)

---
 colossalai/gemini/ag_chunk.py | 366 ++++++++++++++++++++++++++++++++++
 1 file changed, 366 insertions(+)
 create mode 100644 colossalai/gemini/ag_chunk.py

diff --git a/colossalai/gemini/ag_chunk.py b/colossalai/gemini/ag_chunk.py
new file mode 100644
index 000000000..62d19b75d
--- /dev/null
+++ b/colossalai/gemini/ag_chunk.py
@@ -0,0 +1,366 @@
+import torch
+import torch.distributed as dist
+from typing import Optional, Dict
+
+from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkFullError, \
+    free_storage, alloc_storage
+
+
+class AgChunk:
+    def __init__(self,
+                 chunk_size: int,
+                 process_group: ColoProcessGroup,
+                 dtype: torch.dtype,
+                 init_device: Optional[torch.device] = None,
+                 keep_gathered: bool = False,
+                 pin_memory: bool = False) -> None:
+        """
+        Chunk: a container owning a piece of contiguous memory space for tensors.
+        AgChunk is a kind of chunk that uses the all-gather operation to gather the whole chunk.
+        This kind of chunk is exclusively used for DDP and ZeRO DDP.
+        It is designed to make full use of communication and PCIe bandwidth.
+
+        Args:
+            chunk_size (int): the number of elements in a chunk
+            process_group (ColoProcessGroup): the process group of this chunk
+            dtype (torch.dtype): the data type of the chunk
+            init_device (torch.device): optional, the device where the chunk is initialized.
+                The default value is None, which means the current GPU is used
+            keep_gathered (bool): optional, if True, this chunk is always gathered in CUDA memory
+            pin_memory (bool): optional, if True, this chunk always has a shard copy in pinned CPU memory
+        """
+
+        self.chunk_size = chunk_size
+        self.utilized_size = 0
+        # Here, we use the torch process group,
+        # since ColoProcessGroup might get deprecated soon
+        self.torch_pg = process_group.dp_process_group
+        self.pg_size = dist.get_world_size(self.torch_pg)
+        self.pg_rank = dist.get_rank(self.torch_pg)
+
+        # the chunk size should be divisible by the size of the process group
+        assert chunk_size % self.pg_size == 0
+        self.shard_size = chunk_size // self.pg_size
+        self.shard_begin = self.shard_size * self.pg_rank
+        self.shard_end = self.shard_begin + self.shard_size
+
+        self.dtype = dtype
+        device = init_device or get_current_device()
+        self.chunk_temp = torch.empty(chunk_size, dtype=dtype, device=device)
+        self.chunk_total = None    # chunk_total is forced to stay in CUDA memory
+        self.cuda_shard = None    # two attributes are used here for better readability
+        self.cpu_shard = None
+        self.is_gathered = True
+
+        self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
+        self.shard_mem = self.chunk_mem // self.pg_size
+
+        # each tensor is associated with a TensorInfo to track its meta info
+        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
+        # the total number of tensors in the chunk
+        self.num_tensors = 0
+        # monitor the states of all tensors
+        self.tensors_state_monitor: Dict[TensorState, int] = dict()
+        for state in TensorState:
+            self.tensors_state_monitor[state] = 0
+
+        # some chunks can be kept gathered all the time,
+        # so their computation pattern is the same as that of the parameters in DDP
+        self.keep_gathered = keep_gathered
+
+        # if pin_memory is True, we allocate a piece of pinned CPU memory
+        # for this chunk all the time
+        self.pin_memory = pin_memory
+
+        # we introduce the paired chunk here:
+        # it refers to another chunk holding the same parameters
+        # but with a different dtype (e.g. fp16_chunk.paired_chunk -> fp32_chunk)
+        self.paired_chunk = None
+        # if the gradient of this chunk has been reduced, the flag is True
+        # so the flag stays False for unused parameters
+        self.grad_reduced_flag = False
+        # if this chunk is synchronized with the optimizer, the flag is True
+        self.optim_sync_flag = True
+        # if the cpu_shard has been visited during the training step, the flag is True
+        self.cpu_vis_flag = False
+
+    @property
+    def memory_usage(self):
+        cuda_memory = 0
+        cpu_memory = 0
+
+        if self.chunk_temp is not None:
+            # this chunk is not closed
+            if self.chunk_temp.device.type == 'cuda':
+                cuda_memory += self.chunk_mem
+            else:
+                cpu_memory += self.chunk_mem
+        else:
+            if self.is_gathered:
+                cuda_memory += self.chunk_mem
+            if self.cuda_shard is not None:
+                cuda_memory += self.shard_mem
+            if self.cpu_shard is not None:
+                cpu_memory += self.shard_mem
+
+        return dict(cuda=cuda_memory, cpu=cpu_memory)
+
+    @property
+    def device_type(self):
+        if self.chunk_temp is not None:
+            return self.chunk_temp.device.type
+        else:
+            if self.chunk_total is not None:
+                return 'cuda'
+            elif self.cuda_shard is not None:
+                return 'cuda'
+            else:
+                return 'cpu'
+
+    def append_tensor(self, tensor: torch.Tensor):
+        """Add a tensor to the chunk.
+
+        Args:
+            tensor (torch.Tensor): a tensor to be added to the chunk
+        """
+        # sanity check
+        assert self.chunk_temp is not None
+        assert tensor.dtype == self.dtype
+
+        new_utilized_size = self.utilized_size + tensor.numel()
+        # raise an exception when the chunk size is exceeded
+        if new_utilized_size > self.chunk_size:
+            raise ChunkFullError
+
+        self.chunk_temp[self.utilized_size: new_utilized_size].copy_(tensor.flatten())
+        assert type(self.chunk_temp) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
+        tensor.data = self.chunk_temp[self.utilized_size: new_utilized_size].view(tensor.shape)
+
+        # record all the information about the tensor
+        self.num_tensors += 1
+        tensor_state = TensorState.HOLD
+        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
+        self.tensors_state_monitor[tensor_state] += 1
+        self.utilized_size = new_utilized_size
+
+    def close_chunk(self, shard_dev: torch.device):
+        """Close the chunk. No tensor can be appended to a closed chunk.
+        """
+        # sanity check
+        assert self.chunk_temp is not None
+
+        if self.chunk_temp.device.type == 'cpu':
+            self.chunk_total = self.chunk_temp.to(get_current_device())
+        else:
+            self.chunk_total = self.chunk_temp
+        self.chunk_temp = None
+
+        self.__scatter()
+
+        if self.pin_memory or shard_dev.type == 'cpu':
+            self.cpu_shard = torch.empty(self.shard_size,
+                                         dtype=self.dtype,
+                                         pin_memory=self.pin_memory)
+            self.cpu_shard.copy_(self.cuda_shard)
+            self.cpu_vis_flag = True    # cpu_shard has been visited
+
+        if shard_dev.type == 'cpu':
+            self.cuda_shard = None
+
+    def shard_move(self, device: torch.device, force_copy: bool = False):
+        # sanity check
+        assert not self.is_gathered
+        # when the current chunk is not synchronized with the optimizer,
+        # we use another way to perform the movement
+        if not self.optim_sync_flag:
+            assert device.type == 'cuda', "each chunk should first be moved to CUDA"
+            self.__paired_shard_move()
+            self.optim_sync_flag = True
+            return
+
+        if device.type == 'cuda':
+            assert device == get_current_device(), "can't move chunk to another device"
+
+            if self.cuda_shard is not None:
+                return
+
+            self.cuda_shard = self.cpu_shard.to(get_current_device())
+
+            if not self.pin_memory:
+                self.cpu_shard = None
+        elif device.type == 'cpu':
+            if self.cuda_shard is None:
+                return
+
+            if self.pin_memory:
+                if force_copy or not self.cpu_vis_flag:
+                    self.cpu_shard.copy_(self.cuda_shard)
+                # if cpu_shard has already been visited,
+                # the copy operation is not needed
+            else:
+                self.cpu_shard = self.cuda_shard.cpu()
+            self.cpu_vis_flag = True
+            self.cuda_shard = None
+        else:
+            raise NotImplementedError
+
+    def access_chunk(self):
+        """Make the chunk usable for the parameters inside it.
+        It is an operation done in CUDA.
+        """
+        # sanity check
+        assert self.chunk_temp is None
+
+        if not self.is_gathered:
+            self.__gather()
+        self.__update_tensors_ptr()
+
+    def release_chunk(self):
+        """Release the gathered chunk.
+        It is an operation done in CUDA.
+        """
+        # sanity check
+        assert self.chunk_temp is None
+
+        if self.is_gathered:
+            self.__scatter()
+
+    def reduce(self):
+        """Reduce-scatter all the gradients.
+        It is an operation done in CUDA.
+        """
+        # sanity check
+        assert self.is_gathered
+
+        if self.pg_size == 1:
+            # tricky code here:
+            # just move chunk_total to cuda_shard,
+            # the communication is not necessary
+            self.__scatter()
+        elif self.keep_gathered:
+            # we use all-reduce here
+            dist.all_reduce(self.chunk_total, group=self.torch_pg)
+        else:
+            self.cuda_shard = torch.empty(
+                self.shard_size, dtype=self.dtype, device=get_current_device())
+
+            input_list = list(torch.chunk(self.chunk_total, chunks=self.pg_size, dim=0))
+            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+
+            free_storage(self.chunk_total)
+            self.is_gathered = False
+        self.__update_tensors_state(TensorState.HOLD)
+        self.grad_reduced_flag = True
+
+    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
+        """
+        Make a transition of the tensor into the next state.
+
+        Args:
+            tensor (torch.Tensor): a torch Tensor object.
+            tensor_state (TensorState): the target state for transition.
+        """
+
+        # As the gradient hook can be triggered either before or after post-backward,
+        # a tensor's state can go compute -> hold_after_bwd -> ready_for_reduce
+        # or compute -> ready_for_reduce -> hold_after_bwd.
+        # The second path is invalid, so we just ignore ready_for_reduce -> hold_after_bwd.
+        # This function only applies valid state transitions;
+        # invalid calls are ignored and nothing changes.
+        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
+            # print(
+            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
+            # )
+            return
+        self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
+
+    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
+        """
+        Copy data slice to the memory space indexed by the input tensor in the chunk.
+
+        Args:
+            tensor (torch.Tensor): the tensor used to retrieve meta information
+            data_slice (torch.Tensor): the tensor to be copied to the chunk
+        """
+        # sanity check
+        assert self.is_gathered
+
+        tensor_info = self.tensors_info[tensor]
+        self.chunk_total[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
+        tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
+
+    @property
+    def can_release(self) -> bool:
+        return self.tensors_state_monitor[TensorState.HOLD] == self.num_tensors
+
+    @property
+    def can_reduce(self):
+        return self.tensors_state_monitor[TensorState.READY_FOR_REDUCE] == self.num_tensors
+
+    def __gather(self):
+        if not self.is_gathered:
+            # sanity check
+            assert self.cuda_shard is not None
+
+            if self.pg_size == 1:
+                self.chunk_total = self.cuda_shard
+            else:
+                alloc_storage(self.chunk_total)
+                gather_list = list(torch.chunk(
+                    input=self.chunk_total, chunks=self.pg_size, dim=0))
+                dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
+
+            self.cuda_shard = None
+            self.is_gathered = True
+
+    def __scatter(self):
+        if self.keep_gathered:
+            return
+
+        if self.is_gathered:
+            # sanity check
+            assert self.cuda_shard is None
+
+            self.cuda_shard = torch.empty(self.shard_size,
+                                          dtype=self.dtype,
+                                          device=self.chunk_total.device)
+
+            self.cuda_shard.copy_(self.chunk_total[self.shard_begin: self.shard_end])
+
+            free_storage(self.chunk_total)
+            self.is_gathered = False
+
+    def __paired_shard_move(self):
+        assert self.paired_chunk is not None, "chunks should be paired before training"
+        optim_chunk = self.paired_chunk
+        assert self.chunk_size == optim_chunk.chunk_size
+
+        # this function is only called when the optimizer state is in CPU memory;
+        # the gradient and the parameter should be on the same device
+        assert self.cuda_shard is None
+        temp = optim_chunk.cpu_shard.to(get_current_device())
+        # cast on CUDA to avoid doing the FP32 conversion on CPU
+        self.cuda_shard = temp.to(self.dtype)
+
+        if not self.pin_memory:
+            self.cpu_shard = None
+
+    def __update_tensors_ptr(self) -> None:
+        # sanity check
+        assert self.is_gathered
+        assert type(self.chunk_total) == torch.Tensor
+
+        for tensor, tensor_info in self.tensors_info.items():
+            tensor.data = self.chunk_total[tensor_info.offset:tensor_info.end].view(tensor.shape)
+
+    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
+        self.tensors_state_monitor[tensor_info.state] -= 1
+        tensor_info.state = next_state
+        self.tensors_state_monitor[tensor_info.state] += 1
+
+    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
+        for tensor_info in self.tensors_info.values():
+            if prev_state is None or tensor_info.state == prev_state:
+                self.__update_one_tensor_info(tensor_info, next_state)
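
Reviewer note (not part of the patch): the sketch below shows one way the new AgChunk API could be exercised end to end. It is a minimal, hypothetical example, assuming torch.distributed has already been initialized through the usual colossalai launcher, that ColoProcessGroup() with default arguments yields the data-parallel group, and that the world size divides the chunk size; the chunk size (1024) and parameter shapes are illustrative only.

    import torch
    from colossalai.utils import get_current_device
    from colossalai.tensor import ProcessGroup as ColoProcessGroup
    from colossalai.gemini.ag_chunk import AgChunk

    # assumption: colossalai launch has already set up torch.distributed
    pg = ColoProcessGroup()    # hypothetical default data-parallel group
    chunk = AgChunk(chunk_size=1024, process_group=pg, dtype=torch.float16, pin_memory=True)

    # pack two fp16 parameters (512 elements each) into the chunk;
    # append_tensor makes each tensor a view into the chunk buffer
    p1 = torch.randn(256, 2, dtype=torch.float16, device=get_current_device())
    p2 = torch.randn(512, dtype=torch.float16, device=get_current_device())
    chunk.append_tensor(p1)
    chunk.append_tensor(p2)

    # close the chunk: it is scattered and each rank keeps a
    # chunk_size // world_size shard, here stored in pinned CPU memory
    chunk.close_chunk(shard_dev=torch.device('cpu'))

    # before computation: move the local shard to CUDA and all-gather the chunk
    chunk.shard_move(get_current_device())
    chunk.access_chunk()    # p1 / p2 point into the gathered chunk again

    # ... forward / backward would run here ...

    # once the parameters are no longer needed, scatter the chunk back to shards
    chunk.release_chunk()

The gradient path (driving each tensor to READY_FOR_REDUCE via tensor_trans_state and then calling reduce()) is expected to be handled by the DDP / ZeRO DDP wrapper mentioned in the class docstring and is not shown here.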