[embedding] rename FreqAwareEmbedding -> CachedEmbedding (#1699)

This commit is contained in:
Jiarui Fang
2022-10-13 22:22:27 +08:00
committed by GitHub
parent 0e52f3d3d5
commit 21962e1593
8 changed files with 77 additions and 76 deletions

View File

@@ -12,8 +12,8 @@ from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \
ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
from typing import List
NUM_EMBED, EMBED_DIM = 10, 8
@@ -106,13 +106,13 @@ def test_reorder_with_freq():
def test_freq_aware_embed(use_LFU: bool):
device = torch.device('cuda', 0)
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
model = FreqAwareEmbeddingBag(NUM_EMBED,
EMBED_DIM,
mode='mean',
include_last_offset=True,
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
ids_freq_mapping=None,
evict_strategy=evict_strategy).to(device)
model = CachedEmbeddingBag(NUM_EMBED,
EMBED_DIM,
mode='mean',
include_last_offset=True,
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
ids_freq_mapping=None,
evict_strategy=evict_strategy).to(device)
assert model.weight.shape[0] == NUM_EMBED
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
@@ -151,14 +151,14 @@ def test_freq_aware_embed(use_LFU: bool):
@pytest.mark.parametrize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool):
# minimal test to check behavior
Bag = FreqAwareEmbeddingBag(5,
5,
cache_ratio=3 / 5,
buffer_size=0,
pin_weight=True,
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
warmup_ratio=1.0,
evict_strategy=EvictionStrategy.LFU)
Bag = CachedEmbeddingBag(5,
5,
cache_ratio=3 / 5,
buffer_size=0,
pin_weight=True,
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
warmup_ratio=1.0,
evict_strategy=EvictionStrategy.LFU)
# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
offsets = torch.tensor([0], device="cuda:0")
@@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
_weight = torch.cat([weight_table1, weight_table2], 0)
else:
_weight = weight_table3
model = ParallelFreqAwareEmbeddingBagTablewise(
model = ParallelCachedEmbeddingBagTablewise(
embedding_bag_config_list,
embedding_dim=5,
_weight=_weight,
@@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
model = ParallelFreqAwareEmbeddingBag.from_pretrained(
model = ParallelCachedEmbeddingBag.from_pretrained(
coloweight,
include_last_offset=True,
freeze=False,