[embedding] add tablewise sharding for FAW (#1526)

Author: CsRic
Date: 2022-09-01 17:55:41 +08:00
Committed by: GitHub
Parent: f1e1836218
Commit: 5156d5b4f8
6 changed files with 273 additions and 13 deletions

@@ -12,7 +12,9 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -200,7 +202,72 @@ def gather_tensor(tensor, rank, world_size):
     return gather_list
 
 
-def run_parallel_freq_aware_embed(rank, world_size):
+def run_parallel_freq_aware_embed_tablewise(rank, world_size):
+    if world_size != 2:
+        return
+    device = torch.device('cuda', torch.cuda.current_device())
+    # initialize weight
+    # 3 feature tables. idx: 0~5, 6~10, 11~17
+    weight_table1 = torch.rand(6, 5)
+    weight_table2 = torch.rand(5, 5)
+    weight_table3 = torch.rand(7, 5)
+    embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
+    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
+        num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
+    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
+        num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
+    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
+        num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
+    model = ParallelFreqAwareEmbeddingBagTablewise(
+        embedding_bag_config_list,
+        embedding_dim=5,
+        evict_strategy=EvictionStrategy.LFU,
+        include_last_offset=True
+    )
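+    # tablewise sharding per the configs above: tables 1 and 2 are assigned to rank 0,
+    # table 3 to rank 1; each table caches at most cuda_row_num=4 hot rows on CUDA and
+    # evicts the rest to its CPU copy using the LFU strategy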
+    # demo explain:
+    '''
+        batch       feature 1       feature 2       feature 3
+        input0      [1,2,3]         [6,7]           []
+        input1      []              [9]             [13,15]
+        input2      [1,5]           [6,8]           [11]
+                    ↑               ↑               ↑
+                    rank 0          rank 0          rank 1
+        in KJT format
+    '''
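+    # KJT layout: indices are concatenated feature by feature ([1,2,3] + [] + [1,5] for
+    # feature 1, then feature 2, then feature 3), and the offsets tensor holds the
+    # cumulative bag boundaries (with include_last_offset=True):
+    # [0, 3, 3, 5, 7, 8, 10, 10, 12, 13] for the 3 features x 3 inputs above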
+    res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+    rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
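+    # rand_grad covers the full output (3 samples x 3 tables * 5 dims); each rank
+    # backpropagates only the rows of its local output batch
+    # (rank 0: samples 0 and 1, rank 1: sample 2)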
+    if rank == 0:
+        fake_grad = rand_grad[0:2]
+    else:
+        fake_grad = rand_grad[2:]
+    res.backward(fake_grad)
+    optimizer.step()
+    optimizer.zero_grad()
+    # check correctness on weight_table2
+    if rank == 0:
+        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_table2.detach().clone(),
+                                                          include_last_offset=True,
+                                                          freeze=False).to(device)
+        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
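+        # feature 2 occupies columns 5:10 of the concatenated output, so rand_grad[:, 5:10]
+        # is its gradient; its raw indices [6,7], [9], [6,8] map to table-local indices
+        # [0,1], [3], [0,2] with offsets [0, 2, 3, 5]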
+        ref_grad = rand_grad[:, 5:10]
+        ref_res = ref_model(torch.tensor([0, 1, 3, 0, 2], device=device), torch.tensor([0, 2, 3, 5], device=device))
+        ref_res.backward(ref_grad)
+        ref_optimizer.step()
+        ref_optimizer.zero_grad()
+        model.freq_aware_embedding_bag_list[1].cache_weight_mgr.flush()    # update cpu weight
+        recover_weight = model.freq_aware_embedding_bag_list[1].cache_weight_mgr.weight
+        assert torch.allclose(recover_weight, ref_model.weight.detach().cpu()
+                              ), f"{recover_weight - ref_model.weight.detach().cpu()}"
+
+
+def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
 
     num_embed = 100
@@ -219,7 +286,8 @@ def run_parallel_freq_aware_embed(rank, world_size):
     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                           include_last_offset=True,
                                                           freeze=False,
-                                                          cuda_row_num=batch_size * 2)
+                                                          cuda_row_num=batch_size * 2,
+                                                          )
 
     assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
@@ -269,7 +337,8 @@ def run_parallel_freq_aware_embed(rank, world_size):
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_parallel_freq_aware_embed(rank, world_size)
+    # run_parallel_freq_aware_embed_columnwise(rank, world_size)
+    run_parallel_freq_aware_embed_tablewise(rank, world_size)
 
 
 @pytest.mark.dist
@@ -281,6 +350,6 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
-    test_freq_aware_embed(True)
-    # test_parallel_freq_aware_embed(2)
+    # test_freq_aware_embed(True)
+    test_parallel_freq_aware_embed(2)
     # test_lfu_strategy(False)