Mirror of https://github.com/hpcaitech/ColossalAI.git
[FAW] refactor reorder() for CachedParamMgr (#1514)
@@ -144,49 +144,52 @@ def test_freq_aware_embed(use_LFU: bool):
     assert torch.allclose(model_weight, ref_weight), \
         f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
 
-def test_lfu_strategy():
-    # minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(
-        5,
-        5,
-        cuda_row_num=3,
-        buffer_size=0,
-        pin_weight=True,
-        warmup_ratio=0.0,
-        evict_strategy=EvictionStrategy.LFU
-    )
-
-    offsets = torch.tensor([0],device="cuda:0")
+@pytest.mark.parametrize('init_freq', [True, False])
+def test_lfu_strategy(init_freq: bool):
+    # minimal test to check behavior
+    Bag = FreqAwareEmbeddingBag(5,
+                                5,
+                                cuda_row_num=3,
+                                buffer_size=0,
+                                pin_weight=True,
+                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                                warmup_ratio=1.0,
+                                evict_strategy=EvictionStrategy.LFU)
+
+    # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
+    offsets = torch.tensor([0], device="cuda:0")
 
     # prepare frequency learning info:
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
 
     # check strategy
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([3],device="cuda:0"),offsets)    # miss, evict 1
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)    # hit
-    Bag.forward(torch.tensor([4],device="cuda:0"),offsets)    # miss, evict 3
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)    # hit
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)    # hit
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([3], device="cuda:0"), offsets)    # miss, evict 1
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([4], device="cuda:0"), offsets)    # miss, evict 3
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)    # hit
 
     assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
         "LFU strategy behavior failed"
 
 
 def gather_tensor(tensor, rank, world_size):
     gather_list = []
     if rank == 0:
@@ -279,4 +282,4 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
     # test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
-    test_lfu_strategy()
+    test_lfu_strategy(False)
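
The assert in the test expects the last six per-forward hit counts to be [3, 0, 1, 0, 1, 1]. As a minimal sketch of the LFU bookkeeping the test exercises (plain Python, independent of ColossalAI; TinyLFUCache, its access() method, and its fields are invented here for illustration and are not CachedParamMgr's API), the following toy cache reproduces that tail when fed the same index batches:

from collections import Counter


class TinyLFUCache:
    """Toy LFU cache: holds at most `capacity` ids and, on a miss with a full
    cache, evicts the cached id with the lowest access frequency (ties broken
    arbitrarily). Illustration only -- not ColossalAI's CachedParamMgr."""

    def __init__(self, capacity, ids_freq_mapping=None):
        self.capacity = capacity
        self.freq = Counter()    # lifetime access count per id
        if ids_freq_mapping is not None:
            # optional warm-start, mirroring the ids_freq_mapping argument in the test
            for idx, f in enumerate(ids_freq_mapping):
                self.freq[idx] = f
        self.cached = set()
        self.num_hits_history = []    # hits per access() call

    def access(self, ids):
        hits = 0
        for i in ids:
            self.freq[i] += 1
            if i in self.cached:
                hits += 1
            elif len(self.cached) < self.capacity:
                self.cached.add(i)
            else:
                # evict the least frequently used cached id, then admit the new one
                victim = min(self.cached, key=lambda c: self.freq[c])
                self.cached.remove(victim)
                self.cached.add(i)
        self.num_hits_history.append(hits)


cache = TinyLFUCache(capacity=3)
batches = ([2], [1, 2], [0, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2],
           [0, 2], [0, 2], [0, 2], [0, 2], [0], [0], [0], [0],    # prepare frequencies
           [0, 1, 2], [0, 1, 2], [3], [2], [4], [2], [0])         # check strategy
for batch in batches:
    cache.access(batch)
print(cache.num_hits_history[-6:])    # [3, 0, 1, 0, 1, 1]

In the test itself, init_freq=True additionally passes ids_freq_mapping=[4, 2, 1, 3, 1], which presumably seeds the frequency counts before any forward pass; the toy's optional ids_freq_mapping argument mirrors that idea.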