[FAW] init an LFU implementation for FAW (#1488)

Author:    Jiarui Fang
Date:      2022-08-24 17:37:22 +08:00
Committer: GitHub
Parent:    32efe8e740
Commit:    cde7b8a5b8
5 changed files with 112 additions and 39 deletions
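For context: the freq-aware (FAW) embedding keeps only the hottest embedding rows in a limited CUDA cache, and this commit adds an LFU (least-frequently-used) eviction strategy alongside the existing DATASET one. Below is a minimal standalone sketch of LFU eviction; the LFUCache class and all of its names are illustrative assumptions, not CachedParamMgr's actual internals.

from collections import defaultdict

class LFUCache:
    # Illustrative LFU sketch only; not ColossalAI's CachedParamMgr.
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.freq = defaultdict(int)  # row id -> lifetime access count
        self.resident = set()         # row ids currently cached on CUDA

    def access(self, row_id: int):
        # Record an access; return the evicted row id, or None.
        self.freq[row_id] += 1
        if row_id in self.resident:
            return None  # cache hit, nothing to evict
        evicted = None
        if len(self.resident) >= self.capacity:
            # Evict the resident row with the smallest access count.
            evicted = min(self.resident, key=lambda r: self.freq[r])
            self.resident.remove(evicted)
        self.resident.add(row_id)
        return evicted

cache = LFUCache(capacity=2)
for rid in [0, 0, 1, 2]:          # row 0 is hot, row 1 is coldest
    cache.access(rid)
print(sorted(cache.resident))     # [0, 2] -- row 1 was evicted

A production implementation would likely track counts in CUDA tensors and evict in batches, but the policy decision is the same: the row with the lowest lifetime frequency goes first.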


@@ -12,7 +12,7 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
     return indices, offsets
 
 
+@pytest.mark.skip
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
@@ -98,14 +99,17 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
 
 
-def test_freq_aware_embed():
+@pytest.mark.parametrize('use_LFU', [True, False])
+def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
+    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
     model = FreqAwareEmbeddingBag(NUM_EMBED,
                                   EMBED_DIM,
                                   mode='mean',
                                   include_last_offset=True,
                                   cuda_row_num=BATCH_SIZE * 2,
-                                  ids_freq_mapping=None).to(device)
+                                  ids_freq_mapping=None,
+                                  evict_strategy=evict_strategy).to(device)
 
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
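With the @pytest.mark.parametrize decorator above, the test runs once per eviction strategy. Assuming pytest is installed, both variants can be selected by name without guessing the test file's path:

pytest -k test_freq_aware_embed

The final hunk below also switches the __main__ entry point to call the LFU variant directly.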
@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):
 
 
 if __name__ == '__main__':
-    test_cachemgr()
-    # test_freq_aware_embed()
+    test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)