[FAW] LFU cache for the FAW

Author: CsRic
Date: 2022-08-25 13:08:46 +08:00
Committed by: GitHub
Parent: 9145aef2b4
Commit: b8d0e39eaf

2 changed files with 60 additions and 9 deletions
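The test added below exercises the new LFU eviction strategy of FreqAwareEmbeddingBag: when the CUDA row cache is full, the resident row with the lowest access count is evicted first. The following is a minimal, illustrative sketch of that policy (hypothetical names, assuming per-row access counts are kept for the whole run; it is not the ColossalAI cache-manager implementation):

from collections import defaultdict


class LFURowCacheSketch:
    """Toy LFU row cache: evict the resident row with the smallest access count."""

    def __init__(self, capacity: int):
        self.capacity = capacity       # e.g. cuda_row_num=3 in the test below
        self.cached_rows = set()       # row ids currently resident in the cache
        self.freq = defaultdict(int)   # per-row access count

    def access(self, row_id: int) -> bool:
        """Record one access; return True on a cache hit, False on a miss."""
        self.freq[row_id] += 1
        if row_id in self.cached_rows:
            return True
        if len(self.cached_rows) >= self.capacity:
            # cache is full: evict the least frequently used resident row
            victim = min(self.cached_rows, key=self.freq.__getitem__)
            self.cached_rows.remove(victim)
        self.cached_rows.add(row_id)
        return False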

@@ -144,6 +144,44 @@ def test_freq_aware_embed(use_LFU: bool):
    assert torch.allclose(model_weight, ref_weight), \
        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"


def test_lfu_strategy():
    # minimal test to check LFU eviction behavior
    Bag = FreqAwareEmbeddingBag(5,
                                5,
                                cuda_row_num=3,
                                buffer_size=0,
                                pin_weight=True,
                                warmup_ratio=0.0,
                                evict_strategy=EvictionStrategy.LFU)

    offsets = torch.tensor([0], device="cuda:0")

    # prepare frequency learning info:
    # after these forwards, access counts are row 0 > row 2 > row 1
    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)

    # check strategy: the least frequently used rows should be evicted first
    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)    # 3 hits
    Bag.forward(torch.tensor([3], device="cuda:0"), offsets)    # miss, evict 1
    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
    Bag.forward(torch.tensor([4], device="cuda:0"), offsets)    # miss, evict 1
    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)    # hit
    assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]),
                          torch.Tensor([3, 0, 1, 0, 1, 1])), \
        "LFU strategy behavior failed"


def gather_tensor(tensor, rank, world_size):
    gather_list = []
@@ -237,3 +275,4 @@ def test_parallel_freq_aware_embed(world_size):
if __name__ == '__main__':
    test_freq_aware_embed(True)
    # test_parallel_freq_aware_embed(2)
    # test_lfu_strategy()
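For reference, the expected tail of num_hits_history in test_lfu_strategy can be reproduced by replaying the same access pattern under plain LFU semantics. This is a sketch under the assumptions that warmup_ratio=0.0 leaves the cache empty at the start and that per-row access counts accumulate over the whole run; it is not derived from the cache-manager code:

from collections import defaultdict

capacity = 3                      # cuda_row_num in the test
cached = set()                    # rows currently resident in the cache
freq = defaultdict(int)           # per-row access count
num_hits_history = []             # hits per forward, mirroring the manager's history

batches = [[0, 1, 2]] * 4 + [[0, 2]] * 4 + [[0]] * 4 \
    + [[0, 1, 2], [3], [2], [4], [2], [0]]

for batch in batches:
    hits = 0
    for row in batch:
        freq[row] += 1
        if row in cached:
            hits += 1
            continue
        if len(cached) >= capacity:
            # evict the least frequently used resident row
            victim = min(cached, key=freq.__getitem__)
            cached.remove(victim)
        cached.add(row)
    num_hits_history.append(hits)

print(num_hits_history[-6:])      # [3, 0, 1, 0, 1, 1], matching the assertion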