from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.tensor import FakeTensor

from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState, ts_update_sanity_check


class ChunkFullError(Exception):
    pass


@dataclass
class TensorInfo:
    state: TensorState
    fake_data: FakeTensor
    offset: int
    end: int


class Chunk:
    """Chunk is a data structure that stores a sequence of tensors in one
    contiguous memory block. Moreover, Chunk manages the storage of its
    tensors in a distributed way: normally, a chunk is scattered across its
    process group. When a tensor in this chunk is needed later, the chunk can
    be gathered by `access_chunk`. When training is done, the chunk can be
    scattered again by `reduce_chunk`.

    Args:
        rcache: the memory pool that stores replicated chunks
        chunk_size: the size of the chunk
        chunk_dtype: the dtype of the chunk
        process_group: the torch communication group of the chunk
        temp_device: the device that stores the temporary chunk during initialization
        shard_device: the device that stores the shard of the scattered chunk
        rcache_fused: whether this chunk is fused in rcache without eviction
        cpu_pin_memory: whether this chunk uses CPU pinned memory for its shard
    """
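    # A rough lifecycle sketch (hypothetical usage; it assumes an initialized
    # process group and a MemoryPool instance `mp`):
    #
    #     chunk = Chunk(rcache=mp, chunk_size=1024, chunk_dtype=torch.float16,
    #                   process_group=dist.GroupMember.WORLD)
    #     chunk.append_tensor(param)      # pack parameters while building
    #     chunk.close_chunk()             # scatter: each rank keeps one shard
    #     chunk.access_chunk(block)       # gather into an rcache block for compute
    #     block = chunk.reduce_chunk()    # reduce-scatter gradients
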
    total_count = 0

    def __init__(self,
                 rcache: MemoryPool,
                 chunk_size: int,
                 chunk_dtype: torch.dtype,
                 process_group: ProcessGroup,
                 temp_device: Optional[torch.device] = None,
                 shard_device: Optional[torch.device] = None,
                 rcache_fused: bool = False,    # whether this chunk is used in ZeRO2
                 cpu_pin_memory: bool = False    # whether this chunk has a permanent copy in cpu
                 ) -> None:
        self.chunk_id: int = Chunk.total_count
        Chunk.total_count += 1
        # set the replicated cache pool
        self.rcache: MemoryPool = rcache

        self.chunk_size: int = chunk_size
        self.chunk_dtype: torch.dtype = chunk_dtype
        self.utilized_size: int = 0

        self.torch_pg: ProcessGroup = process_group
        self.pg_size: int = dist.get_world_size(self.torch_pg)
        self.pg_rank: int = dist.get_rank(self.torch_pg)

        # the chunk size should be divisible by the dp degree
        assert chunk_size % self.pg_size == 0

        self.shard_size: int = chunk_size // self.pg_size
        self.shard_begin: int = self.shard_size * self.pg_rank
        self.shard_end: int = self.shard_begin + self.shard_size
        self.valid_end: int = self.shard_size + 1    # set to an illegal number

        # notice: release blocks reserved by PyTorch
        torch.cuda.empty_cache()

        # rcache block: the global replicated chunk in the R cache
        self.rcb: Optional[TensorBlock] = None
        self.rcache_fused: bool = rcache_fused
        self._my_block = None
        self.is_replica: bool = True
        # allocate a private block for fused chunks
        if self.rcache_fused:
            self._my_block = rcache.get_private_block(chunk_size, chunk_dtype)

        temp_device: torch.device = temp_device or gpu_device()
        # chunk_temp is a global chunk that only exists while the chunks are being built
        # all its elements are kept at zero
        self.chunk_temp: Optional[torch.Tensor] = None
        if rcache_fused:
            self.chunk_temp = self._my_block.payload
            torch.zero_(self.chunk_temp)
        else:
            self.chunk_temp = torch.zeros(chunk_size, dtype=chunk_dtype, device=temp_device)

        # configure the init device of the shard
        # no-offload default: fp16, fp32 -> CUDA
        # offload default: fp16, fp32 -> CPU
        shard_device: torch.device = shard_device or torch.device('cpu')
        pin_flag: bool = cpu_pin_memory and shard_device.type == 'cpu'
        # chunk.shard is a local chunk
        # it is designed to exist permanently
        self.shard: torch.Tensor = torch.empty(self.shard_size,
                                               dtype=chunk_dtype,
                                               device=shard_device,
                                               pin_memory=pin_flag)

        # calculate the memory occupation of the chunk and the shard
        self.chunk_memo: int = self.chunk_size * self.chunk_temp.element_size()
        self.shard_memo: int = self.chunk_memo // self.pg_size

        # each tensor is associated with a TensorInfo that tracks its meta info
        # (state, fake data, offset, end)
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # the total number of tensors in the chunk
        self.num_tensors: int = 0

        # record the number of tensors in each state
        self.tensor_state_cnter: Dict[TensorState, int] = dict()
        for state in TensorState:
            self.tensor_state_cnter[state] = 0

        # we introduce the paired chunk here
        # it refers to another chunk that holds the same parameters
        # but with a different dtype (such as fp16_chunk.paired_chunk -> fp32_chunk)
        self.paired_chunk = None
        # if this chunk is synchronized with the optimizer, the flag is True
        self.optim_sync_flag = True
        # whether to record the l2 norm for the gradient clipping calculation
        self.l2_norm_flag = False
        self.l2_norm = None
        # whether it overflows after the reduction
        self.overflow = False

    @property
    def prepared_block(self):
        return self._my_block

    @property
    def is_init(self):
        return self.chunk_temp is not None

    @property
    def in_rcache(self):
        return self.rcb is not None

    @property
    def shard_device(self):
        return self.shard.device

    @property
    def memory_usage(self) -> Dict[str, int]:
        cuda_memory = 0
        cpu_memory = 0

        # this chunk is not closed yet
        if self.is_init:
            if self.chunk_temp.device.type == 'cuda':
                cuda_memory += self.chunk_memo
            else:
                cpu_memory += self.chunk_memo

        # this chunk is in the rcache
        if self.in_rcache:
            cuda_memory += self.rcb.memo_occ

        # calculate the occupation of the chunk shard
        if self.shard_device.type == 'cuda':
            cuda_memory += self.shard_memo
        elif self.shard_device.type == 'cpu':
            cpu_memory += self.shard_memo
        else:
            raise NotImplementedError

        return dict(cuda=cuda_memory, cpu=cpu_memory)

    @property
    def payload(self) -> torch.Tensor:
        if self.is_init:
            return self.chunk_temp
        if self.in_rcache:
            return self.rcb.payload
        else:
            return self.shard

    @property
    def shard_move_check(self) -> bool:
        return not self.in_rcache

    def _not_compute_number(self):
        total = 0
        state_list = [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE]
        for state in state_list:
            total += self.tensor_state_cnter[state]
        return total

    @property
    def scatter_check(self) -> bool:
        if self.rcache_fused:
            return False
        return self._not_compute_number() == self.num_tensors

    @property
    def reduce_check(self):
        return self.tensor_state_cnter[TensorState.READY_FOR_REDUCE] == self.num_tensors

    def enable_l2_norm_flag(self) -> None:
        self.l2_norm_flag = True

    def set_overflow_flag(self, valid_tensor: torch.Tensor) -> None:
        assert not self.overflow
        self.overflow = torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
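    # Note: the squared norm is stored below rather than the norm itself,
    # presumably so that per-chunk values can be summed across chunks and
    # ranks before one final square root is taken for global gradient clipping.
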
    def set_l2_norm(self, valid_tensor: torch.Tensor) -> None:
        assert self.l2_norm is None, 'you are calculating the l2 norm twice'
        chunk_l2_norm = valid_tensor.data.float().norm(2)
        self.l2_norm = chunk_l2_norm.item()**2

    def append_tensor(self, tensor: torch.Tensor):
        # sanity check
        assert self.is_init
        assert tensor.dtype == self.chunk_dtype

        new_utilized_size = self.utilized_size + tensor.numel()
        # raise an exception when the chunk size is exceeded
        if new_utilized_size > self.chunk_size:
            raise ChunkFullError

        self.chunk_temp[self.utilized_size:new_utilized_size].copy_(tensor.data.flatten())
        tensor.data = self.chunk_temp[self.utilized_size:new_utilized_size].view(tensor.shape)
        fake_data = FakeTensor(tensor.data)

        # record all the information about the tensor
        self.num_tensors += 1
        tensor_state = TensorState.HOLD

        self.tensor_state_cnter[tensor_state] += 1
        self.tensors_info[tensor] = TensorInfo(state=tensor_state,
                                               fake_data=fake_data,
                                               offset=self.utilized_size,
                                               end=new_utilized_size)
        self.utilized_size = new_utilized_size

    def close_chunk(self):
        # sanity check
        assert self.is_init

        # calculate the valid end for this rank's shard
        if self.utilized_size <= self.shard_begin:
            self.valid_end = 0
        elif self.utilized_size < self.shard_end:
            self.valid_end = self.utilized_size - self.shard_begin

        self.__remove_tensors_ptr()
        self.__update_shard(self.chunk_temp, self.shard)
        self.is_replica = False
        self.chunk_temp = None

    def replicate(self):
        assert not self.is_replica

        self.is_replica = True
        this_shard = self.shard if self.optim_sync_flag else self.__paired_shard()
        self.__update_replica(self.rcb.payload, this_shard)
        self.__update_tensors_ptr()

    def scatter(self):
        assert not self.rcache_fused
        assert self.is_replica

        self.__remove_tensors_ptr()
        if not self.optim_sync_flag:
            self.__update_shard(self.rcb.payload, self.shard)
            self.optim_sync_flag = True
        self.is_replica = False

    def reduce(self, always_fp32: bool = False):
        assert self.is_replica

        self.__remove_tensors_ptr()
        if self.pg_size > 1:
            cast_to_fp32 = False
            if always_fp32 and self.chunk_dtype != torch.float:
                cast_to_fp32 = True
                # cast the payload to fp32
                reduce_buffer = self.rcb.payload.to(dtype=torch.float)
            else:
                # otherwise, use the payload directly
                reduce_buffer = self.rcb.payload

            # divide the reduce buffer by the size of the process group
            reduce_buffer /= self.pg_size
            # try to use an in-place reduce-scatter
            # notice: PyTorch does not allow a truly in-place reduce-scatter,
            # because it allocates a contiguous memory space for collective communications
            shard_buffer = reduce_buffer[self.shard_begin:self.shard_end]
            dist.reduce_scatter_tensor(shard_buffer, reduce_buffer, group=self.torch_pg)

            # the result should be moved back to the payload for the norm calculation
            if cast_to_fp32:
                calc_buffer = self.rcb.payload[self.shard_begin:self.shard_end]
                calc_buffer.copy_(shard_buffer)
        else:
            # if the process group size equals 1, do not communicate
            reduce_buffer = self.rcb.payload

        self.__update_shard(reduce_buffer, self.shard)
        self.is_replica = False

    def access_chunk(self, block: Optional[TensorBlock] = None):
        # sanity check
        assert not self.is_init
        assert not self.is_replica

        if self.rcache_fused:
            assert block is None
            self.rcb = self._my_block
        else:
            assert block in self.rcache.public_used_blocks
            assert self.rcb is None
            self.rcb = block

        self.replicate()

    def release_chunk(self) -> TensorBlock:
        # sanity check
        assert not self.is_init
        assert self.is_replica

        if self.rcache_fused:
            raise RuntimeError

        self.scatter()
        block = self.rcb
        self.rcb = None
        return block
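    # A cache manager would typically pair access_chunk/release_chunk, e.g.
    # (hypothetical sketch; the MemoryPool call for obtaining a free public
    # block is an assumed helper, not part of this file):
    #
    #     block = rcache.pop_public_block()     # hypothetical helper
    #     chunk.access_chunk(block)             # gather shards into the block
    #     ...forward/backward on chunk.payload...
    #     free_block = chunk.release_chunk()    # scatter; the block can be reused
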
    def update_extra_reduce_info(self, block: Optional[TensorBlock]):
        if self.rcache_fused:
            assert block is None
            block = self._my_block
        else:
            assert block is not None

        buffer = block.payload[self.shard_begin:self.shard_end]
        valid_tensor = buffer[:self.valid_end]
        self.set_overflow_flag(valid_tensor)
        if self.l2_norm_flag:
            self.set_l2_norm(valid_tensor)

    def reduce_chunk(self, always_fp32: bool = False, sync: bool = True) -> Optional[TensorBlock]:
        """Reduce-scatter all the gradients. This operation runs on CUDA.
        """
        # sanity check
        assert not self.is_init
        assert self.is_replica

        self.reduce(always_fp32=always_fp32)
        self.__update_tensors_state(TensorState.HOLD)

        # reset the rcb pointer
        block = self.rcb
        self.rcb = None
        if self.rcache_fused:
            block = None

        if sync:
            self.update_extra_reduce_info(block)

        return block

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        prev_state = self.tensors_info[tensor].state
        if prev_state == tensor_state:
            return
        if ts_update_sanity_check(prev_state, tensor_state):
            self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        # sanity check
        assert self.is_replica

        info = self.tensors_info[tensor]
        payload = self.rcb.payload
        payload[info.offset:info.end].copy_(data_slice.data.flatten())
        tensor.data = payload[info.offset:info.end].view(tensor.shape)

    def init_pair(self, friend_chunk: 'Chunk') -> None:
        if self.paired_chunk is None and friend_chunk.paired_chunk is None:
            self.paired_chunk = friend_chunk
            friend_chunk.paired_chunk = self
        else:
            assert self.paired_chunk is friend_chunk
            assert friend_chunk.paired_chunk is self

    def optim_update(self) -> None:
        """Update the fp16 chunks via their fp32 chunks. It's used by the optimizer.
        """
        # sanity check
        assert self.paired_chunk is not None

        friend_chunk: Chunk = self.paired_chunk
        assert not friend_chunk.is_replica
        # the gradient and the optimizer state should be on the same device
        assert self.shard_device.type == friend_chunk.shard_device.type

        if self.shard_device.type == 'cuda':
            self.shard.copy_(friend_chunk.shard)
            self.optim_sync_flag = True
        elif self.shard_device.type == 'cpu':
            # optim_sync_flag is set to False
            # see the shard_move function for more details
            self.optim_sync_flag = False
        else:
            raise NotImplementedError

    def get_tensors(self) -> List[torch.Tensor]:
        return list(self.tensors_info.keys())

    def get_cpu_copy(self, only_rank_0: bool = False) -> List[torch.Tensor]:
        assert not self.is_init

        if self.is_replica:
            # use the payload directly when this chunk is a replica
            temp_buffer = self.rcb.payload
        else:
            # otherwise, create a temporary buffer
            temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
            # cheat the assertion in __update_replica
            self.is_replica = True
            self.__update_replica(temp_buffer, self.shard)
            self.is_replica = False

        cpu_copys = [None] * self.num_tensors
        if not only_rank_0 or self.pg_rank == 0:
            for i, (t, info) in enumerate(self.tensors_info.items()):
                t_copy = temp_buffer[info.offset:info.end].view(t.shape).cpu()
                cpu_copys[i] = t_copy
        # synchronize
        dist.barrier()
        return cpu_copys
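    # get_cpu_copy and load_tensors below form a save/restore pair for
    # checkpointing, e.g. (hypothetical usage):
    #
    #     state = chunk.get_cpu_copy(only_rank_0=True)    # gather, then copy to CPU
    #     chunk.load_tensors(state, only_rank_0=True)     # restore and broadcast
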
    def load_tensors(self, tensor_list: List[Optional[torch.Tensor]], only_rank_0: bool = False) -> None:
        assert not self.is_replica
        assert not self.is_init

        temp_buffer = torch.empty(self.chunk_size, dtype=self.chunk_dtype, device=gpu_device())
        # cheat the assertion in __update_replica
        self.is_replica = True
        self.__update_replica(temp_buffer, self.shard)
        self.is_replica = False

        if not only_rank_0 or self.pg_rank == 0:
            for (_, c_info), load_tensor in zip(self.tensors_info.items(), tensor_list):
                if load_tensor is None:
                    continue
                temp_buffer[c_info.offset:c_info.end].copy_(load_tensor.data.flatten())

        # synchronize
        dist.barrier()
        if only_rank_0:
            dist.broadcast(temp_buffer, src=0, group=self.torch_pg)

        # cheat the assertion in __update_shard
        self.is_replica = True
        self.__update_shard(temp_buffer, self.shard)
        self.is_replica = False

    def __update_replica(self, replica: torch.Tensor, shard: torch.Tensor):
        assert self.is_replica
        assert replica.numel() == self.chunk_size
        assert shard.numel() == self.shard_size

        buffer = replica[self.shard_begin:self.shard_end]
        buffer.copy_(shard)
        dist.all_gather_into_tensor(replica, buffer, group=self.torch_pg)

    def __update_shard(self, replica: torch.Tensor, shard: torch.Tensor):
        assert self.is_replica
        assert replica.numel() == self.chunk_size
        assert shard.numel() == self.shard_size

        shard.copy_(replica[self.shard_begin:self.shard_end])

    def __paired_shard(self):
        assert self.paired_chunk is not None, 'chunks should be paired before training'
        optim_chunk: Chunk = self.paired_chunk
        assert self.chunk_size == optim_chunk.chunk_size

        # this is only called when the optimizer state is in CPU memory
        # the grad and the param should be on the same device
        assert self.shard_device.type == 'cpu'
        return optim_chunk.shard.to(gpu_device())

    def __remove_tensors_ptr(self) -> None:
        # sanity check
        # each tensor should point to its fake data before the scatter
        assert self.is_replica
        for tensor, info in self.tensors_info.items():
            tensor.data = info.fake_data

    def __update_tensors_ptr(self) -> None:
        # sanity check
        # the chunk should be replicated to get the correct pointer
        assert self.is_replica
        payload = self.rcb.payload
        for tensor, info in self.tensors_info.items():
            tensor.data = payload[info.offset:info.end].view(tensor.shape)

    def __update_one_tensor_info(self, tensor_info: TensorInfo, next_state: TensorState):
        self.tensor_state_cnter[tensor_info.state] -= 1
        tensor_info.state = next_state
        self.tensor_state_cnter[tensor_info.state] += 1

    def __update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                self.__update_one_tensor_info(tensor_info, next_state)

    def __hash__(self) -> int:
        return self.chunk_id

    def __lt__(self, other: object) -> bool:
        return self.chunk_id < other.chunk_id

    def __eq__(self, other: object) -> bool:
        return self.chunk_id == other.chunk_id

    def __repr__(self, detailed: bool = True):
        if self.is_init:
            state = 'initialization'
        elif self.in_rcache:
            state = 'replicated'
        else:
            state = 'scattered'

        output = [
            f'Chunk {self.chunk_id} details: state -> {state}\n',
            f'  length: {self.chunk_size}, dtype: {self.chunk_dtype}, group_size: {self.pg_size}, tensors: {self.num_tensors}\n',
            f'  utilized size: {self.utilized_size}, utilized percentage: {100 * (self.utilized_size / self.chunk_size):.0f}%\n'
        ]

        memory_info = self.memory_usage
        output.append('  memory usage: (cuda -> {}, cpu -> {})\n'.format(memory_info['cuda'], memory_info['cpu']))

        def print_tensor(name, tensor, prefix=''):
            output.append(f'{prefix}{name}: (shape={tensor.shape}, dtype={tensor.dtype}, device={tensor.device})\n')

        if self.is_init:
            print_tensor(name='temp', tensor=self.chunk_temp, prefix='  ')
        if self.in_rcache:
            print_tensor(name='block', tensor=self.rcb.payload, prefix='  ')
        if self.shard is not None:
            print_tensor(name='shard', tensor=self.shard, prefix='  ')

        if detailed:
            output.append('  tensor state monitor:\n')
            for st in TensorState:
                output.append('    # of {}: {}\n'.format(st, self.tensor_state_cnter[st]))

        return ''.join(output)
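
# A minimal end-to-end sketch of the chunk lifecycle, for illustration only.
# It assumes a distributed launch (e.g. `torchrun`) with CUDA available, and
# the MemoryPool constructor shown here is an assumed signature, not part of
# this file:
#
#     import torch
#     import torch.distributed as dist
#
#     dist.init_process_group(backend='nccl')
#     mp = MemoryPool(device_type='cuda')    # hypothetical signature
#     chunk = Chunk(rcache=mp, chunk_size=1024, chunk_dtype=torch.float32,
#                   process_group=dist.GroupMember.WORLD, rcache_fused=True)
#     for p in [torch.randn(256) for _ in range(4)]:
#         chunk.append_tensor(p)             # p now aliases chunk memory
#     chunk.close_chunk()                    # keep only the local shard
#     chunk.access_chunk()                   # fused chunks need no public block
#     # ...compute gradients into chunk.payload...
#     chunk.reduce_chunk()                   # reduce-scatter + extra bookkeeping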