[embeddings] add already_split_along_rank flag for tablewise mode (#1584)

CsRic authored 2022-09-13 10:50:34 +08:00, committed by GitHub
parent 77399dc91b
commit f3403ff98e
2 changed files with 41 additions and 17 deletions


@@ -253,7 +253,8 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size):
     in KJT format
     '''
     res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
-                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
+                already_split_along_rank=False)
     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
     rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
     if rank == 0:
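
For reference, a minimal sketch of how the new keyword is exercised in the tablewise test path. It assumes `model` is the ParallelFreqAwareEmbeddingBagTablewise instance built earlier in the test and `device` is the current rank's device; the behavior described for the `True` path is an inference from the flag name only and is not shown in this diff.

import torch

# Assumption: `model` and `device` come from the surrounding test setup
# (an initialized process group with one tablewise-sharded embedding bag).
indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device)
offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)

# already_split_along_rank=False: every rank passes the full, global
# KJT-style batch, as in the updated test above; the module is then
# expected to handle the per-rank split itself (inferred from the flag name).
res = model(indices, offsets, already_split_along_rank=False)

# Assumed alternative, inferred from the flag name only: if each rank has
# already sliced out its own shard of the batch, pass True instead.
# res = model(local_indices, local_offsets, already_split_along_rank=True)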