mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[embedding] freq_aware_embedding: add small functions for caller application (#1537)
This commit is contained in:
@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
|
||||
ColoTensor, ColoTensorSpec
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
|
||||
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
|
||||
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
|
||||
from typing import List
|
||||
|
||||
NUM_EMBED, EMBED_DIM = 10, 8
|
||||
@@ -209,9 +209,10 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
|
||||
|
||||
# initialize weight
|
||||
# 3 feature tables. idx: 0~5, 6~10, 11~17
|
||||
weight_table1 = torch.rand(6, 5)
|
||||
weight_table2 = torch.rand(5, 5)
|
||||
weight_table3 = torch.rand(7, 5)
|
||||
weight_tables = torch.rand(18,5)
|
||||
weight_table1 = weight_tables[0:6]
|
||||
weight_table2 = weight_tables[6:11]
|
||||
weight_table3 = weight_tables[11:18]
|
||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
|
||||
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
|
||||
num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
|
||||
@@ -219,14 +220,20 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
|
||||
num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
|
||||
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
|
||||
num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
|
||||
|
||||
if rank == 0:
|
||||
_weight = torch.cat([weight_table1, weight_table2],0)
|
||||
else:
|
||||
_weight = weight_table3
|
||||
model = ParallelFreqAwareEmbeddingBagTablewise(
|
||||
embedding_bag_config_list,
|
||||
embedding_dim=5,
|
||||
_weight=_weight,
|
||||
include_last_offset=True,
|
||||
cuda_row_num=8,
|
||||
buffer_size=0,
|
||||
evict_strategy=EvictionStrategy.LFU,
|
||||
include_last_offset=True
|
||||
)
|
||||
# demo explain:
|
||||
# explain
|
||||
'''
|
||||
batch feature 1 feature 2 feature 3
|
||||
input0 [1,2,3] [6,7] []
|
||||
@@ -244,28 +251,27 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
|
||||
fake_grad = rand_grad[0:2]
|
||||
else :
|
||||
fake_grad = rand_grad[2:]
|
||||
|
||||
res.backward(fake_grad)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# check correctness on weight_table2
|
||||
# check correctness
|
||||
if rank == 0:
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_table2.detach().clone(),
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
|
||||
include_last_offset=True,
|
||||
freeze=False).to(device)
|
||||
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
|
||||
ref_grad = rand_grad[:, 5:10]
|
||||
ref_res = ref_model(torch.tensor([0, 1, 3, 0, 2], device=device), torch.tensor([0, 2, 3, 5], device=device))
|
||||
ref_res.backward(ref_grad)
|
||||
ref_fake_grad = torch.cat(rand_grad.split(5,1),0)
|
||||
ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
|
||||
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
|
||||
ref_res.backward(ref_fake_grad)
|
||||
ref_optimizer.step()
|
||||
ref_optimizer.zero_grad()
|
||||
|
||||
model.freq_aware_embedding_bag_list[1].cache_weight_mgr.flush() # update cpu weight
|
||||
recover_weight = model.freq_aware_embedding_bag_list[1].cache_weight_mgr.weight
|
||||
assert torch.allclose(recover_weight, ref_model.weight.detach().cpu()
|
||||
), f"{recover_weight - ref_model.weight.detach().cpu()}"
|
||||
|
||||
|
||||
model.cache_weight_mgr.flush()
|
||||
recover_weight = model.cache_weight_mgr.weight.to(device)
|
||||
ref_weight = ref_model.weight.detach()[:11]
|
||||
assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
|
||||
|
||||
def run_parallel_freq_aware_embed_columnwise(rank, world_size):
|
||||
device = torch.device('cuda', torch.cuda.current_device())
|
||||
|
Reference in New Issue
Block a user