[FAW] LFU cache for the FAW

This commit is contained in:
CsRic
2022-08-25 13:08:46 +08:00
committed by GitHub
parent 9145aef2b4
commit b8d0e39eaf
2 changed files with 60 additions and 9 deletions

View File

@@ -59,7 +59,7 @@ class CachedParamMgr(torch.nn.Module):
"""_update_freq_cnter
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
Args:
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
@@ -80,7 +80,7 @@ class CachedParamMgr(torch.nn.Module):
if self._evict_strategy == EvictionStrategy.LFU:
# find the minimal evict_num freq entries in cached_idx_map
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
return self.cached_idx_map[evict_gpu_row_idxs]
return evict_gpu_row_idxs
elif self._evict_strategy == EvictionStrategy.DATASET:
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
@@ -298,15 +298,27 @@ class CachedParamMgr(torch.nn.Module):
if evict_num > 0:
with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
# another mask method.
# set freq_cnter[invalid_idxs] to max
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_cnter = self.freq_cnter[invalid_idxs].clone()
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value?
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
if self.buffer_size > 0: