[embeddings] use cache_ratio instead of cuda_row_num (#1611)

This commit is contained in:
Jiarui Fang
2022-09-20 14:33:04 +08:00
committed by GitHub
parent 6a8f8cc05e
commit 504ff1d101
5 changed files with 16 additions and 14 deletions

View File

@@ -110,7 +110,7 @@ def test_freq_aware_embed(use_LFU: bool):
EMBED_DIM,
mode='mean',
include_last_offset=True,
cuda_row_num=BATCH_SIZE * 2,
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
ids_freq_mapping=None,
evict_strategy=evict_strategy).to(device)
@@ -153,7 +153,7 @@ def test_lfu_strategy(init_freq: bool):
# minimal test to check behavior
Bag = FreqAwareEmbeddingBag(5,
5,
cuda_row_num=3,
cache_ratio=3 / 5,
buffer_size=0,
pin_weight=True,
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
@@ -238,7 +238,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
embedding_dim=5,
_weight=_weight,
include_last_offset=True,
cuda_row_num=8,
cache_ratio=0.5,
buffer_size=0,
evict_strategy=EvictionStrategy.LFU,
)
@@ -304,7 +304,7 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
coloweight,
include_last_offset=True,
freeze=False,
cuda_row_num=batch_size * 2,
cache_ratio=batch_size * 2 / num_embed,
)
assert model.cache_weight_mgr.weight.device.type == 'cpu'