Add FreqAwareEmbeddingBag (#1421)

Jiarui Fang
2022-08-09 16:26:12 +08:00
committed by GitHub
parent 6df3e19be9
commit d209aff684
3 changed files with 145 additions and 3 deletions


@@ -7,12 +7,26 @@ import numpy as np
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.nn._ops.cache_embedding import CachedParamMgr
+from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
 
-NUM_EMBED, EMBED_DIM = 100, 8
+NUM_EMBED, EMBED_DIM = 10, 8
+BATCH_SIZE = 8
+
+
+def synthesize_1d_sparse_feature(
+    batch_size,
+    num_embed,
+    device,
+):
+    indices_in_batch = batch_size * 2
+    indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
+    offsets = torch.from_numpy(
+        np.array([
+            0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch
+        ])).to(device).long()
+    return indices, offsets
+
+
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
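For readers unfamiliar with the 1D input format that synthesize_1d_sparse_feature produces: indices is the flat concatenation of every bag's ids, and with include_last_offset=True offsets carries one trailing entry so that bag i spans indices[offsets[i]:offsets[i+1]]. A minimal standalone illustration with plain torch.nn.EmbeddingBag (not part of the commit):

import torch

indices = torch.tensor([3, 7, 1, 4, 9, 0])    # ids of all bags, flattened
offsets = torch.tensor([0, 2, 2, 5, 6])       # bag i = indices[offsets[i]:offsets[i+1]]
bag = torch.nn.EmbeddingBag(10, 4, mode='mean', include_last_offset=True)
out = bag(indices, offsets)
print(out.shape)                               # torch.Size([4, 4]): 4 bags, embed dim 4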
@@ -70,6 +84,50 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
 
 
+def test_freq_aware_embed():
+    device = torch.device('cuda', 0)
+    model = FreqAwareEmbeddingBag(
+        NUM_EMBED,
+        EMBED_DIM,
+        mode='mean',
+        include_last_offset=True,
+    ).to(device)
+    model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
+    assert model.weight.shape[0] == NUM_EMBED
+
+    ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
+                                                      mode='mean',
+                                                      include_last_offset=True,
+                                                      freeze=False)
+    assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+    ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
+
+    for i in range(5):
+        indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
+        res = model(indices, offsets)
+        ref_res = ref_model(indices, offsets)
+        assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"
+
+        grad = torch.rand_like(res)
+        # comparing gradient here is nontrivial
+        res.backward(grad)
+        ref_res.backward(grad)
+
+        optimizer.step()
+        optimizer.zero_grad()
+        ref_optimizer.step()
+        ref_optimizer.zero_grad()
+
+    model.cache_weight_mgr.flush()
+    model_weight = model.weight.detach().to(device)
+    ref_weight = ref_model.weight.detach()
+    assert torch.allclose(model_weight, ref_weight), \
+        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
+
+
 if __name__ == '__main__':
+    # test_freq_aware_embed()
     # test_chunkmgr_admit()
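One detail worth spelling out: during training the freshly updated rows live in the CUDA cache, so model.weight can lag behind until cache_weight_mgr.flush() writes them back; only then is the final allclose against the reference weight meaningful. A hedged helper capturing that check, assuming the attribute names used in the test above:

import torch

def weights_match(freq_aware_model, ref_model, device):
    # Flush rows held in the CUDA cache so the full weight reflects every update.
    freq_aware_model.cache_weight_mgr.flush()
    return torch.allclose(freq_aware_model.weight.detach().to(device),
                          ref_model.weight.detach())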