[FAW] refactor reorder() for CachedParamMgr (#1514)

This commit is contained in:
Jiarui Fang
2022-08-29 14:22:07 +08:00
committed by GitHub
parent 9feee6d06b
commit af5438caa2
2 changed files with 63 additions and 51 deletions

View File

@@ -144,49 +144,52 @@ def test_freq_aware_embed(use_LFU: bool):
assert torch.allclose(model_weight, ref_weight), \
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
def test_lfu_strategy():
# minimal test to check behavior
Bag = FreqAwareEmbeddingBag(
5,
5,
cuda_row_num=3,
buffer_size=0,
pin_weight=True,
warmup_ratio=0.0,
evict_strategy=EvictionStrategy.LFU
)
offsets = torch.tensor([0],device="cuda:0")
@pytest.mark.parametrize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool):
# minimal test to check behavior
Bag = FreqAwareEmbeddingBag(5,
5,
cuda_row_num=3,
buffer_size=0,
pin_weight=True,
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
warmup_ratio=1.0,
evict_strategy=EvictionStrategy.LFU)
# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
offsets = torch.tensor([0], device="cuda:0")
# prepare frequency learning info:
Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
# check strategy
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([3],device="cuda:0"),offsets) # miss, evict 1
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 3
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([0],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1
Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit
Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3
Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit
Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
"LFU strategy behavior failed"
def gather_tensor(tensor, rank, world_size):
gather_list = []
if rank == 0:
@@ -279,4 +282,4 @@ def test_parallel_freq_aware_embed(world_size):
if __name__ == '__main__':
# test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
test_lfu_strategy()
test_lfu_strategy(False)