[embedding] polish parallel embedding tablewise (#1545)

This commit is contained in:
Jiarui Fang
2022-09-06 10:41:20 +08:00
committed by GitHub
parent 46c6cc79a9
commit 64169f3e8f
6 changed files with 232 additions and 204 deletions

View File

@@ -13,7 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
ColoTensor, ColoTensorSpec
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelFreqAwareEmbeddingBagTablewiseSpiltCache
ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
from typing import List
NUM_EMBED, EMBED_DIM = 10, 8
@@ -209,19 +209,28 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
# initialize weight
# 3 feature tables. idx: 0~5, 6~10, 11~17
weight_tables = torch.rand(18,5)
weight_tables = torch.rand(18, 5)
weight_table1 = weight_tables[0:6]
weight_table2 = weight_tables[6:11]
weight_table3 = weight_tables[11:18]
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
embedding_bag_config_list.append(
TablewiseEmbeddingBagConfig(num_embeddings=6,
cuda_row_num=4,
assigned_rank=0,
initial_weight=weight_table1.clone().detach().cpu()))
embedding_bag_config_list.append(
TablewiseEmbeddingBagConfig(num_embeddings=5,
cuda_row_num=4,
assigned_rank=0,
initial_weight=weight_table2.clone().detach().cpu()))
embedding_bag_config_list.append(
TablewiseEmbeddingBagConfig(num_embeddings=7,
cuda_row_num=4,
assigned_rank=1,
initial_weight=weight_table3.clone().detach().cpu()))
if rank == 0:
_weight = torch.cat([weight_table1, weight_table2],0)
_weight = torch.cat([weight_table1, weight_table2], 0)
else:
_weight = weight_table3
model = ParallelFreqAwareEmbeddingBagTablewise(
@@ -249,30 +258,31 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
if rank == 0:
fake_grad = rand_grad[0:2]
else :
else:
fake_grad = rand_grad[2:]
res.backward(fake_grad)
optimizer.step()
optimizer.zero_grad()
# check correctness
# check correctness
if rank == 0:
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
include_last_offset=True,
freeze=False).to(device)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
ref_fake_grad = torch.cat(rand_grad.split(5,1),0)
ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
ref_res.backward(ref_fake_grad)
ref_optimizer.step()
ref_optimizer.zero_grad()
model.cache_weight_mgr.flush()
recover_weight = model.cache_weight_mgr.weight.to(device)
ref_weight = ref_model.weight.detach()[:11]
assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
def run_parallel_freq_aware_embed_columnwise(rank, world_size):
device = torch.device('cuda', torch.cuda.current_device())
@@ -289,11 +299,12 @@ def run_parallel_freq_aware_embed_columnwise(rank, world_size):
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
include_last_offset=True,
freeze=False,
cuda_row_num=batch_size * 2,
)
model = ParallelFreqAwareEmbeddingBag.from_pretrained(
coloweight,
include_last_offset=True,
freeze=False,
cuda_row_num=batch_size * 2,
)
assert model.cache_weight_mgr.weight.device.type == 'cpu'
assert model.cache_weight_mgr.cuda_cached_weight.requires_grad