diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index 0e6bc4ecd..61b15c8e1 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -14,7 +14,6 @@ class EvictionStrategy(Enum): DATASET = 2 - class CachedParamMgr(torch.nn.Module): """ Manage Embedding Weights on CPU and CUDA memory uses a software cache. @@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module): # cache_row_idx -> frequency, freq of the cache rows. # classic lfu cache. evict the minimal freq value row in cuda cache. self.register_buffer("freq_cnter", - torch.empty(self.cuda_row_num, - device=torch.cuda.current_device(), + torch.empty(self.cuda_row_num, device=torch.cuda.current_device(), dtype=torch.long).fill_(sys.maxsize), persistent=False) @@ -82,14 +80,14 @@ class CachedParamMgr(torch.nn.Module): """ if self._evict_strategy == EvictionStrategy.LFU: # find the minimal evict_num freq entries in cached_idx_map - _,evict_gpu_row_idxs = torch.topk(self.freq_cnter,evict_num,largest=False) + _, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False) return evict_gpu_row_idxs elif self._evict_strategy == EvictionStrategy.DATASET: # cached_idx_map itself implies the priority of eviction. # The value of self.cached_idx_map represents cpu_row_idx. # The larger it is, the less frequently it will appear in the dataset, # and the higher its eviction priority will be. - _,evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True) + _, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True) return evict_gpu_row_idxs else: raise TypeError @@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module): reorder the weight according to ids' frequency in dataset before training. Execute only once before training, also known as warmup phase. 
- :NOTE If you would like to use the DATASET as the eviction strategy, you must call this function. - :NOTE If you are use the LFU as the eviction strategy, you can skip this function. + Note: + If you would like to use the DATASET as the eviction strategy, you must call this function. + + Note: + If you use LFU as the eviction strategy, you can skip this function. If you still call this function, it will initialize + the frequency in the LFU cache using the dataset statistics. Args: ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight. @@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module): with Timer() as timer: # extract rows from cpu weight preload_row_ids = torch.arange(preload_row_num) - preload_slot_ids = preload_row_ids.cuda() + preload_cuda_row_idxs = preload_row_ids.cuda() if self.buffer_size > 0: self.limit_buff_index_copyer.index_copy(0, src_index=preload_row_ids, - tgt_index=preload_slot_ids, + tgt_index=preload_cuda_row_idxs, src=self.weight.view(self.num_embeddings, -1), tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) else: preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() - self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows) + self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs, + preload_rows) # update auxiliary info - slot_offsets = preload_slot_ids - self.cached_idx_map[preload_slot_ids] = preload_slot_ids - if self._evict_strategy == EvictionStrategy.LFU : - self.freq_cnter.index_fill_(0,preload_slot_ids,0) - self.inverted_cached_idx[preload_slot_ids] = slot_offsets + slot_offsets = preload_cuda_row_idxs + self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs + + if self._evict_strategy == EvictionStrategy.LFU: + # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in 
dataset. + if ids_freq_mapping is None: + self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0) + else: + self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs]) + + self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets self._cuda_available_row_num -= preload_row_num print(f'Cache warmup finished cost {timer.elapsed} sec.') @@ -215,7 +224,7 @@ class CachedParamMgr(torch.nn.Module): self.inverted_cached_idx.index_fill_(0, row_ids, -1) self._cuda_available_row_num += slots.numel() - if self._evict_strategy == EvictionStrategy.LFU : + if self._evict_strategy == EvictionStrategy.LFU: self.freq_cnter.fill_(sys.maxsize) assert self._cuda_available_row_num == self.cuda_row_num assert torch.all(self.inverted_cached_idx == -1).item() @@ -258,7 +267,7 @@ class CachedParamMgr(torch.nn.Module): torch.Tensor: indices on the cuda_cached_weight. """ with record_function("(zhg) get unique indices"): - cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts = True) + cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True) assert len(cpu_row_idxs) <= self.cuda_row_num, \ f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \ @@ -283,10 +292,10 @@ class CachedParamMgr(torch.nn.Module): gpu_row_idxs = self._id_to_cached_cuda_id(ids) # update for LFU. 
- if self._evict_strategy == EvictionStrategy.LFU : + if self._evict_strategy == EvictionStrategy.LFU: unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs] - self.freq_cnter.scatter_add_(0,unique_gpu_row_idxs,repeat_times) - + self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times) + return gpu_row_idxs def _reset_comm_stats(self): @@ -363,7 +372,7 @@ class CachedParamMgr(torch.nn.Module): slot_offsets = slots self.cached_idx_map[slots] = cpu_row_idxs self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets) - if self._evict_strategy == EvictionStrategy.LFU : + if self._evict_strategy == EvictionStrategy.LFU: self.freq_cnter.index_fill_(0, slots, 0) self._cuda_available_row_num -= cpu_row_idxs.numel() self._cpu_to_cuda_elpase += timer.elapsed @@ -407,7 +416,7 @@ class CachedParamMgr(torch.nn.Module): # update inverted_cached_idx, min_slot_id is evicted from cuda self.cached_idx_map[max_cpu_row_idx] = -1 - if self._evict_strategy == EvictionStrategy.LFU : + if self._evict_strategy == EvictionStrategy.LFU: self.freq_cnter[max_cpu_row_idx] = sys.maxsize self.inverted_cached_idx[max_gpu_row_idx] = -1 @@ -443,7 +452,7 @@ class CachedParamMgr(torch.nn.Module): # update the inverted_cached_idx self.cached_idx_map[slot_id] = row_id - if self._evict_strategy == EvictionStrategy.LFU : + if self._evict_strategy == EvictionStrategy.LFU: self.freq_cnter[slot_id] = 0 self.inverted_cached_idx[row_id] = slot_offset