Fixed the padding index issue for vocab parallel embedding layers; updated the 3D linear layer to be compatible with the examples in the tutorial

zbian
2022-02-17 22:03:39 +08:00
committed by Frank Lee
parent 24f8583cc4
commit 3dba070580
6 changed files with 50 additions and 40 deletions
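A note on the first half of the commit message: when the vocabulary dimension of an embedding is sharded across tensor-parallel ranks, the global padding index only exists on the rank that owns its vocabulary slice, so it has to be remapped into local coordinates and skipped everywhere else. The sketch below only illustrates that idea; it is not the ColossalAI implementation, and the helper names and the equal-slice assumption are hypothetical.

import torch
import torch.nn.functional as F

def local_vocab_slice(vocab_size, world_size, rank):
    # Hypothetical helper: assumes the vocabulary splits into equal contiguous slices.
    per_rank = vocab_size // world_size
    start = rank * per_rank
    return start, start + per_rank

def vocab_parallel_embed(token_ids, local_weight, vocab_size, world_size, rank, padding_idx=0):
    # Illustrative sketch, not the library's layer.
    start, end = local_vocab_slice(vocab_size, world_size, rank)
    owned = (token_ids >= start) & (token_ids < end)
    # Map global ids into the local slice; ids owned by other ranks go to row 0
    # and their output rows are zeroed below.
    local_ids = torch.where(owned, token_ids - start, torch.zeros_like(token_ids))
    # Pass the padding index through only when this rank owns it.
    local_pad = padding_idx - start if start <= padding_idx < end else None
    out = F.embedding(local_ids, local_weight, padding_idx=local_pad)
    out = out * owned.unsqueeze(-1).to(out.dtype)
    # A real layer would all-reduce `out` over the vocab-parallel group here.
    return out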


@@ -41,6 +41,7 @@ def check_linear():
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
+weight = torch.chunk(weight, DEPTH, dim=-1)[i]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
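The chunk added above changes the shard each rank is expected to hold for the 3D linear weight: the master tensor is split once along dim 0 (indexed by k) and twice along the last dim (by j, then i), giving each rank a 1/DEPTH**3 slice. Below is a minimal sketch of that selection, reusing the test's DEPTH/i/j/k naming; the helper name is hypothetical, and which dim corresponds to input versus output features depends on the layer implementation.

import torch

def linear3d_weight_shard(weight_master, depth, i, j, k):
    # Mirrors the chunking pattern the updated check expects.
    shard = torch.chunk(weight_master, depth, dim=0)[k]
    shard = torch.chunk(shard, depth, dim=-1)[j]
    shard = torch.chunk(shard, depth, dim=-1)[i]
    return shard

# With depth=2, an (8, 16) master weight yields (4, 4) shards, e.g.
# linear3d_weight_shard(torch.randn(8, 16), 2, i=0, j=1, k=1).shape == torch.Size([4, 4])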
@@ -93,6 +94,7 @@ def check_linear():
B_grad = layer_master.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
+B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
@@ -301,6 +303,7 @@ def check_vocab_parallel_classifier_no_given_weight():
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+weight = torch.chunk(weight, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
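For the vocab parallel layers, the updated checks shard the weight the other way around: the vocabulary dimension (dim 0) is split twice, by j and then i, and the embedding dimension (dim -1) once, by k, so each rank owns vocab_size / DEPTH**2 of the rows. A minimal sketch of that selection under the test's naming follows; the helper name is hypothetical and this only illustrates the expected layout, not the layer itself.

import torch

def vocab_parallel_weight_shard(weight_master, depth, i, j, k):
    # Vocabulary rows are split across two cube coordinates, columns across one.
    shard = torch.chunk(weight_master, depth, dim=0)[j]
    shard = torch.chunk(shard, depth, dim=0)[i]
    shard = torch.chunk(shard, depth, dim=-1)[k]
    return shard

Splitting the vocabulary over two coordinates is also what makes the padding-index remapping sketched above necessary: each rank only sees a vocab_size / DEPTH**2 slice of the row indices.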
@@ -358,6 +361,7 @@ def check_vocab_parallel_classifier_no_given_weight():
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
@@ -470,6 +474,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+weight = torch.chunk(weight, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
@@ -518,6 +523,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
@@ -710,6 +716,7 @@ def check_vocab_parallel_embed():
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
+weight = torch.chunk(weight, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
@@ -751,6 +758,7 @@ def check_vocab_parallel_embed():
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
+B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,