mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 18:39:56 +00:00
[FAW] LFU cache for the FAW
This commit is contained in:
@@ -59,7 +59,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||
"""_update_freq_cnter
|
||||
|
||||
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
|
||||
|
||||
|
||||
Args:
|
||||
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
|
||||
"""
|
||||
@@ -80,7 +80,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# find the minimal evict_num freq entries in cached_idx_map
|
||||
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
|
||||
return self.cached_idx_map[evict_gpu_row_idxs]
|
||||
return evict_gpu_row_idxs
|
||||
elif self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# cached_idx_map itself implies the priority of eviction.
|
||||
# The value of self.cached_idx_map represents cpu_row_idx.
|
||||
@@ -298,15 +298,27 @@ class CachedParamMgr(torch.nn.Module):
|
||||
if evict_num > 0:
|
||||
with Timer() as timer:
|
||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||
|
||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||
|
||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# mask method.
|
||||
# set cached_idx_map[invalid_idxs] to -2.
|
||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||
# another mask method.
|
||||
# set freq_cnter[invalid_idxs] to max
|
||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||
backup_cnter = self.freq_cnter[invalid_idxs].clone()
|
||||
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value?
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter)
|
||||
|
||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||
|
||||
if self.buffer_size > 0:
|
||||
|
Reference in New Issue
Block a user