[FAW] reorganize the inheritance struct of FreqCacheEmbedding (#1448)

Author: Geng Zhang
Date: 2022-08-12 15:55:46 +08:00
Committed by: GitHub
Parent: 5a52e21fe3
Commit: 9f3eed66eb
4 changed files with 189 additions and 150 deletions


@@ -10,7 +10,8 @@ import torch.multiprocessing as mp
 import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
+    ColoTensor, ColoTensorSpec
 from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag

 NUM_EMBED, EMBED_DIM = 10, 8
@@ -99,13 +100,12 @@ def test_reorder_with_freq():
 def test_freq_aware_embed():
     device = torch.device('cuda', 0)
-    model = FreqAwareEmbeddingBag(
-        NUM_EMBED,
-        EMBED_DIM,
-        mode='mean',
-        include_last_offset=True,
-    ).to(device)
-    model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
+    model = FreqAwareEmbeddingBag(NUM_EMBED,
+                                  EMBED_DIM,
+                                  mode='mean',
+                                  include_last_offset=True,
+                                  cuda_row_num=BATCH_SIZE * 2,
+                                  ids_freq_mapping=None).to(device)

     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
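
This hunk carries the core of the reorganization: the cache setup that previously required a separate `preprocess()` call is now passed to the constructor. A minimal before/after sketch, assuming colossalai is installed, a CUDA device is available, and a placeholder `BATCH_SIZE` (the hunk does not show the test's actual value):

```python
import torch
from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag

NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8    # assumed value; the hunk does not show the test's constant

# Before this commit: construct first, then size the CUDA row cache.
#   model = FreqAwareEmbeddingBag(NUM_EMBED, EMBED_DIM, mode='mean',
#                                 include_last_offset=True).to('cuda')
#   model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)

# After this commit: the cache parameters are constructor arguments.
model = FreqAwareEmbeddingBag(NUM_EMBED,
                              EMBED_DIM,
                              mode='mean',
                              include_last_offset=True,
                              cuda_row_num=BATCH_SIZE * 2,
                              ids_freq_mapping=None).to(torch.device('cuda', 0))
assert model.weight.shape[0] == NUM_EMBED
```

Folding `preprocess()` into `__init__` removes the two-phase initialization, so a constructed-but-not-preprocessed module can no longer be used by mistake.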
@@ -159,11 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
     set_seed(4321)
     weight = torch.rand(num_embed, embed_dim)
-    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
+    coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)

     # initialize the tensor spec for the embedding weight parameter,
     # which is a ColoParameter.
-    coloweight.process_group = ProcessGroup(tp_degree=world_size)
+    coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
     coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))

     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
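
For orientation, the new weight-preparation path reads as follows. This is a sketch using only the calls visible in the hunk; it assumes a colossalai distributed context has already been launched, and the literal values are placeholders:

```python
import torch
from colossalai.tensor import (ColoTensor, ProcessGroup, ShardSpec,
                               ComputePattern, ComputeSpec)

num_embed, embed_dim, world_size = 10, 8, 2    # assumed values for illustration

weight = torch.rand(num_embed, embed_dim)
# A plain ColoTensor replaces ColoParameter(..., requires_grad=False).
coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)

# Setter methods replace direct attribute assignment on the tensor:
# shard along the last dim across world_size ranks, 1D tensor parallelism.
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
                           ComputeSpec(ComputePattern.TP1D))
```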
@@ -171,12 +171,12 @@ def run_parallel_freq_aware_embed(rank, world_size):
                                                           freeze=False,
                                                           cuda_row_num=batch_size * 2)
-    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
+    assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
     weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
-    assert torch.allclose(
-        weight_in_rank,
-        model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"
+    print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
+    assert torch.allclose(weight_in_rank,
+                          model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"

     optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
@@ -211,7 +211,7 @@ def run_parallel_freq_aware_embed(rank, world_size):
         ref_optimizer.zero_grad()

     model.cache_weight_mgr.flush()
-    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
+    weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
     if rank == 0:
         recover_weight = torch.cat(weight_list, dim=1)
         assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
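
The same rename appears in the final consistency check: the manager's host-side shard, formerly `cpu_weight`, is now simply `weight`. A sketch of the flush, gather, and reassemble pattern, continuing from the test's own context (`model`, `rank`, `world_size`, `ref_model`, and the test-local `gather_tensor` helper are assumed from above):

```python
# Write any rows still cached on the GPU back to the host-side shard,
# then collect every rank's shard and rebuild the full table on rank 0.
model.cache_weight_mgr.flush()
shard = model.cache_weight_mgr.weight.detach().cuda()   # was `cpu_weight`
weight_list = gather_tensor(shard, rank, world_size)    # test-local helper
if rank == 0:
    recovered = torch.cat(weight_list, dim=1)           # shards were split on dim -1
    assert torch.allclose(recovered, ref_model.weight.detach())
```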
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
 if __name__ == '__main__':
     # test_cachemgr()
     # test_freq_aware_embed()
     # test_chunkmgr_admit()
     test_parallel_freq_aware_embed(2)