mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 01:55:12 +00:00
[FAW] reorganize the inheritance struct of FreqCacheEmbedding (#1448)
This commit is contained in:
@@ -10,7 +10,8 @@ import torch.multiprocessing as mp
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
|
||||
ColoTensor, ColoTensorSpec
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||
|
||||
NUM_EMBED, EMBED_DIM = 10, 8
|
||||
@@ -99,13 +100,12 @@ def test_reorder_with_freq():
|
||||
|
||||
def test_freq_aware_embed():
|
||||
device = torch.device('cuda', 0)
|
||||
model = FreqAwareEmbeddingBag(
|
||||
NUM_EMBED,
|
||||
EMBED_DIM,
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
).to(device)
|
||||
model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
|
||||
model = FreqAwareEmbeddingBag(NUM_EMBED,
|
||||
EMBED_DIM,
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
cuda_row_num=BATCH_SIZE * 2,
|
||||
ids_freq_mapping=None).to(device)
|
||||
|
||||
assert model.weight.shape[0] == NUM_EMBED
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
|
||||
@@ -159,11 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
|
||||
|
||||
set_seed(4321)
|
||||
weight = torch.rand(num_embed, embed_dim)
|
||||
coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
|
||||
coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)
|
||||
|
||||
# initialize the tensor spec for the embedding weight parameter,
|
||||
# which is an ColoParameter.
|
||||
coloweight.process_group = ProcessGroup(tp_degree=world_size)
|
||||
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
|
||||
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
|
||||
@@ -171,12 +171,12 @@ def run_parallel_freq_aware_embed(rank, world_size):
|
||||
freeze=False,
|
||||
cuda_row_num=batch_size * 2)
|
||||
|
||||
assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
|
||||
assert model.cache_weight_mgr.weight.device.type == 'cpu'
|
||||
assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
|
||||
weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
|
||||
assert torch.allclose(
|
||||
weight_in_rank,
|
||||
model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"
|
||||
print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
|
||||
assert torch.allclose(weight_in_rank,
|
||||
model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
|
||||
@@ -211,7 +211,7 @@ def run_parallel_freq_aware_embed(rank, world_size):
|
||||
ref_optimizer.zero_grad()
|
||||
|
||||
model.cache_weight_mgr.flush()
|
||||
weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
|
||||
weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
|
||||
if rank == 0:
|
||||
recover_weight = torch.cat(weight_list, dim=1)
|
||||
assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
|
||||
@@ -231,6 +231,6 @@ def test_parallel_freq_aware_embed(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_cachemgr()
|
||||
# test_freq_aware_embed()
|
||||
# test_chunkmgr_admit()
|
||||
test_parallel_freq_aware_embed(2)
|
||||
|
Reference in New Issue
Block a user