mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[embedding] rename FreqAwareEmbedding -> CachedEmbedding (#1699)
This commit is contained in:
@@ -12,8 +12,8 @@ from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
|
||||
ColoTensor, ColoTensorSpec
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
|
||||
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \
|
||||
ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
|
||||
from typing import List
|
||||
|
||||
NUM_EMBED, EMBED_DIM = 10, 8
|
||||
@@ -106,13 +106,13 @@ def test_reorder_with_freq():
|
||||
def test_freq_aware_embed(use_LFU: bool):
|
||||
device = torch.device('cuda', 0)
|
||||
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
|
||||
model = FreqAwareEmbeddingBag(NUM_EMBED,
|
||||
EMBED_DIM,
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
|
||||
ids_freq_mapping=None,
|
||||
evict_strategy=evict_strategy).to(device)
|
||||
model = CachedEmbeddingBag(NUM_EMBED,
|
||||
EMBED_DIM,
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
|
||||
ids_freq_mapping=None,
|
||||
evict_strategy=evict_strategy).to(device)
|
||||
|
||||
assert model.weight.shape[0] == NUM_EMBED
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
|
||||
@@ -151,14 +151,14 @@ def test_freq_aware_embed(use_LFU: bool):
|
||||
@pytest.mark.parametrize('init_freq', [True, False])
|
||||
def test_lfu_strategy(init_freq: bool):
|
||||
# minimal test to check behavior
|
||||
Bag = FreqAwareEmbeddingBag(5,
|
||||
5,
|
||||
cache_ratio=3 / 5,
|
||||
buffer_size=0,
|
||||
pin_weight=True,
|
||||
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
|
||||
warmup_ratio=1.0,
|
||||
evict_strategy=EvictionStrategy.LFU)
|
||||
Bag = CachedEmbeddingBag(5,
|
||||
5,
|
||||
cache_ratio=3 / 5,
|
||||
buffer_size=0,
|
||||
pin_weight=True,
|
||||
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
|
||||
warmup_ratio=1.0,
|
||||
evict_strategy=EvictionStrategy.LFU)
|
||||
|
||||
# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
|
||||
offsets = torch.tensor([0], device="cuda:0")
|
||||
@@ -233,7 +233,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
|
||||
_weight = torch.cat([weight_table1, weight_table2], 0)
|
||||
else:
|
||||
_weight = weight_table3
|
||||
model = ParallelFreqAwareEmbeddingBagTablewise(
|
||||
model = ParallelCachedEmbeddingBagTablewise(
|
||||
embedding_bag_config_list,
|
||||
embedding_dim=5,
|
||||
_weight=_weight,
|
||||
@@ -300,7 +300,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
|
||||
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
|
||||
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
model = ParallelFreqAwareEmbeddingBag.from_pretrained(
|
||||
model = ParallelCachedEmbeddingBag.from_pretrained(
|
||||
coloweight,
|
||||
include_last_offset=True,
|
||||
freeze=False,
|
||||
|
Reference in New Issue
Block a user