mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[FAW] refactor reorder() for CachedParamMgr (#1514)
This commit is contained in:
@@ -172,44 +172,53 @@ class CachedParamMgr(torch.nn.Module):
|
||||
ids_freq_mapping (List[int]): a list whose index is the id number and whose value is that id's frequency. If None, the cpu weight is not reordered.
|
||||
warmup_ratio (float): the fraction of the cuda cache rows (cuda_row_num) to preload during warmup
|
||||
"""
|
||||
if ids_freq_mapping is not None:
|
||||
if not isinstance(ids_freq_mapping, torch.Tensor):
|
||||
ids_freq_mapping = torch.tensor(ids_freq_mapping)
|
||||
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
||||
sorted_idx = torch.argsort(tmp_idx)
|
||||
self.idx_map.data.copy_(sorted_idx)
|
||||
# reorder phase: reorder the cpu weight according to their freq stats in the target dataset.
|
||||
# reorder only works for DATASET eviction strategy.
|
||||
|
||||
if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor):
|
||||
ids_freq_mapping = torch.tensor(ids_freq_mapping)
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||
if ids_freq_mapping is not None:
|
||||
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
||||
sorted_idx = torch.argsort(tmp_idx)
|
||||
self.idx_map.data.copy_(sorted_idx)
|
||||
|
||||
# warmup phase: copy #preload_row_num rows from cpu to gpu.
|
||||
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
||||
if preload_row_num > 0:
|
||||
with Timer() as timer:
|
||||
# extract rows from cpu weight
|
||||
preload_row_ids = torch.arange(preload_row_num)
|
||||
preload_cuda_row_idxs = preload_row_ids.cuda()
|
||||
if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
|
||||
freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
|
||||
preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
|
||||
else:
|
||||
preload_cpu_ids = torch.arange(preload_row_num)
|
||||
preload_cuda_row_idxs = preload_cpu_ids.cuda()
|
||||
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=preload_row_ids,
|
||||
src_index=preload_cpu_ids,
|
||||
tgt_index=preload_cuda_row_idxs,
|
||||
src=self.weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
|
||||
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
|
||||
preload_rows)
|
||||
|
||||
# update auxiliary info
|
||||
slot_offsets = preload_cuda_row_idxs
|
||||
self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
|
||||
self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
|
||||
self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
|
||||
self._cuda_available_row_num -= preload_row_num
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# if ids_freq_mapping is not None, initialize each preloaded embedding row's LFU freq counter with its freq in the dataset.
|
||||
if ids_freq_mapping is None:
|
||||
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
|
||||
else:
|
||||
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
|
||||
self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
|
||||
|
||||
self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
|
||||
self._cuda_available_row_num -= preload_row_num
|
||||
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
||||
|
||||
def flush(self):
|
||||
|
Reference in New Issue
Block a user