mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[FAW] LFU cache for the FAW
This commit is contained in:
@@ -144,6 +144,44 @@ def test_freq_aware_embed(use_LFU: bool):
|
||||
assert torch.allclose(model_weight, ref_weight), \
|
||||
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
|
||||
|
||||
def test_lfu_strategy():
|
||||
# minimal test to check behavior
|
||||
Bag = FreqAwareEmbeddingBag(
|
||||
5,
|
||||
5,
|
||||
cuda_row_num=3,
|
||||
buffer_size=0,
|
||||
pin_weight=True,
|
||||
warmup_ratio=0.0,
|
||||
evict_strategy=EvictionStrategy.LFU
|
||||
)
|
||||
|
||||
offsets = torch.tensor([0],device="cuda:0")
|
||||
|
||||
# prepare frequency learning info:
|
||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
|
||||
|
||||
# check strategy
|
||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||
Bag.forward(torch.tensor([3],device="cuda:0"),offsets) # miss, evict 1
|
||||
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
|
||||
Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 1
|
||||
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
|
||||
Bag.forward(torch.tensor([0],device="cuda:0"),offsets) # hit
|
||||
|
||||
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
|
||||
"LFU strategy behavior failed"
|
||||
|
||||
def gather_tensor(tensor, rank, world_size):
|
||||
gather_list = []
|
||||
@@ -237,3 +275,4 @@ def test_parallel_freq_aware_embed(world_size):
|
||||
if __name__ == '__main__':
|
||||
test_freq_aware_embed(True)
|
||||
# test_parallel_freq_aware_embed(2)
|
||||
# test_lfu_strategy()
|
Reference in New Issue
Block a user