mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 07:00:37 +00:00
[FAW] parallel FreqAwareEmbedding (#1424)
This commit is contained in:
@@ -3,9 +3,13 @@ from functools import partial
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||
|
||||
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
|
||||
|
||||
@@ -13,6 +17,15 @@ NUM_EMBED, EMBED_DIM = 10, 8
|
||||
BATCH_SIZE = 8
|
||||
|
||||
|
||||
def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch with the same value so that
    subsequent random draws are reproducible.
    """
    # Same order as before: stdlib first, then NumPy, then torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
|
||||
|
||||
|
||||
def synthesize_1d_sparse_feature(
|
||||
batch_size,
|
||||
num_embed,
|
||||
@@ -128,7 +141,91 @@ def test_freq_aware_embed():
|
||||
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
|
||||
|
||||
|
||||
def gather_tensor(tensor, rank, world_size):
    """
    Gather ``tensor`` from every rank onto rank 0.

    Returns a list of ``world_size`` tensors on rank 0 and an empty list on
    every other rank.
    """
    is_dst = rank == 0
    # Only the destination rank allocates receive buffers.
    buffers = [torch.empty_like(tensor) for _ in range(world_size)] if is_dst else []
    torch.distributed.gather(tensor, buffers, dst=0)
    return buffers
|
||||
|
||||
|
||||
def run_parallel_freq_aware_embed(rank, world_size):
    """
    Compare ``ParallelFreqAwareEmbeddingBag`` with ``torch.nn.EmbeddingBag``
    over five SGD steps.

    Every rank builds the parallel model from the same seeded weight. Rank 0
    additionally builds a plain ``EmbeddingBag`` reference, and all output and
    final-weight comparisons against that reference happen on rank 0 only.

    Args:
        rank: index of the calling process; rank 0 holds the reference model.
        world_size: number of participating processes.
    """
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed = 100
    embed_dim = 16
    batch_size = 4

    # Same seed on every rank, so each rank draws the same initial weight.
    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)

    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False,
                                                          cuda_row_num=batch_size * 2)

    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
    # This rank's CPU weight must equal its slice of the full weight, split
    # along the last (embedding) dimension.
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    assert torch.allclose(
        weight_in_rank,
        model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # The unsharded reference model and its optimizer exist on rank 0 only.
    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    # Re-seed before the loop.
    # NOTE(review): presumably this keeps the synthesized inputs and `grad`
    # identical across ranks -- confirm against synthesize_1d_sparse_feature.
    set_seed(4321)
    for i in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        # Draw a full-size upstream gradient, then backpropagate only this
        # rank's slice of it along dim 0.
        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        # Called on every rank; only rank 0 receives a populated list.
        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            # Per-rank outputs are concatenated along dim 0 before comparison.
            recover_res = torch.cat(res_list, dim=0)

            assert torch.allclose(ref_res, recover_res)

            # Step the reference with the full (unsliced) gradient.
            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    # NOTE(review): flush() presumably writes cached CUDA rows back into
    # cpu_weight so the comparison below sees the updated values -- confirm
    # against CachedParamMgr.
    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
    if rank == 0:
        # Weight shards are concatenated along dim 1, mirroring the split
        # along the last dimension checked above.
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
    """
    Per-process entry point: call ``colossalai.launch`` with the NCCL backend
    on localhost, then run the parallel embedding check for this rank.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_kwargs)
    run_parallel_freq_aware_embed(rank, world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    """
    Spawn ``world_size`` processes that each run ``run_dist`` with the same
    port obtained from ``free_port()``.
    """
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Other entry points, currently disabled:
    #   test_freq_aware_embed()
    #   test_chunkmgr_admit()
    test_parallel_freq_aware_embed(2)
|
||||
|
Reference in New Issue
Block a user