mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[FAW] init an LFU implementation for FAW (#1488)
This commit is contained in:
@@ -12,7 +12,7 @@ from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
|
||||
ColoTensor, ColoTensorSpec
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
|
||||
|
||||
NUM_EMBED, EMBED_DIM = 10, 8
|
||||
BATCH_SIZE = 8
|
||||
@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
|
||||
return indices, offsets
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_cachemgr():
|
||||
model = torch.nn.EmbeddingBag(10000, 128)
|
||||
# 10 chunks, 5 in cuda
|
||||
@@ -98,14 +99,17 @@ def test_reorder_with_freq():
|
||||
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
|
||||
|
||||
|
||||
def test_freq_aware_embed():
|
||||
@pytest.mark.parametrize('use_LFU', [True, False])
|
||||
def test_freq_aware_embed(use_LFU: bool):
|
||||
device = torch.device('cuda', 0)
|
||||
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
|
||||
model = FreqAwareEmbeddingBag(NUM_EMBED,
|
||||
EMBED_DIM,
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
cuda_row_num=BATCH_SIZE * 2,
|
||||
ids_freq_mapping=None).to(device)
|
||||
ids_freq_mapping=None,
|
||||
evict_strategy=evict_strategy).to(device)
|
||||
|
||||
assert model.weight.shape[0] == NUM_EMBED
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
|
||||
@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cachemgr()
|
||||
# test_freq_aware_embed()
|
||||
test_freq_aware_embed(True)
|
||||
# test_parallel_freq_aware_embed(2)
|
||||
|
Reference in New Issue
Block a user