updated tp layers

kurisusnowdeng
2022-10-26 20:54:39 +08:00
committed by アマデウス
parent cb5a587e9a
commit 0b8161fab8
13 changed files with 645 additions and 293 deletions

View File

@@ -20,7 +20,6 @@ def check_linear():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
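
The removed dtype assignment was redundant: float32 is PyTorch's default floating-point dtype, so every torch.randn call below produces float32 tensors without naming it. A minimal stand-alone check (plain PyTorch, nothing from this repo assumed):

    import torch

    x = torch.randn(2, 3)                 # no dtype argument
    print(x.dtype)                        # torch.float32
    print(torch.get_default_dtype())      # torch.float32, the process-wide default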
@@ -32,12 +31,12 @@ def check_linear():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
-    layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
+    layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
     layer = layer.to(device)
     layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
     layer_master = layer_master.to(device)
-    weight_master = layer_master.weight.data.transpose(0, 1)
+    weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
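
The added .contiguous() is not cosmetic: torch.nn.Linear keeps its weight as (out_features, in_features), and transpose(0, 1) returns a non-contiguous view over the same storage. Materializing the copy gives the subsequent torch.distributed.broadcast a dense buffer, which collectives generally expect. A single-process sketch of the layout difference (no distributed setup needed):

    import torch

    w = torch.nn.Linear(4, 8).weight.data     # shape (8, 4), row-major
    wt = w.transpose(0, 1)                    # shape (4, 8), same storage, new strides
    print(wt.is_contiguous())                 # False
    print(wt.contiguous().is_contiguous())    # True: copied into row-major layout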
@@ -49,7 +48,7 @@ def check_linear():
     layer.bias.data.copy_(bias)
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -72,7 +71,7 @@ def check_linear():
     logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
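
The double torch.chunk walks the 3D mesh: the first split shards the batch dimension for mesh coordinate i, the second shards the last (feature) dimension for coordinate j, so each rank keeps only its own tile of grad_master. A single-process sketch with assumed values DEPTH = 2, i = 0, j = 1:

    import torch

    DEPTH = 2                                          # assumed cube side of the mesh
    grad_master = torch.randn(8, 4, 16)                # stand-in for (batch, seq, hidden)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[0]   # shard for i = 0
    grad = torch.chunk(grad, DEPTH, dim=-1)[1]         # shard for j = 1
    print(grad.shape)                                  # torch.Size([4, 4, 8])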
@@ -108,7 +107,6 @@ def check_layernorm():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -119,7 +117,7 @@ def check_layernorm():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
-    norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
+    norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
     norm = norm.to(device)
     norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
     norm_master = norm_master.to(device)
@@ -134,7 +132,7 @@ def check_layernorm():
     norm.bias.data.copy_(bias)
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -159,7 +157,7 @@ def check_layernorm():
     logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -193,7 +191,6 @@ def check_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -204,10 +201,10 @@ def check_classifier_no_given_weight():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
-    layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
+    layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
     layer = layer.to(device)
-    layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype)
+    layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
     layer_master = layer_master.to(device)
     weight_master = layer_master.weight.data
@@ -219,7 +216,7 @@ def check_classifier_no_given_weight():
     layer.bias.data.copy_(bias_master)
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -242,7 +239,7 @@ def check_classifier_no_given_weight():
     logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -283,7 +280,6 @@ def check_vocab_parallel_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -295,10 +291,10 @@ def check_vocab_parallel_classifier_no_given_weight():
     k = global_context.get_local_rank(output_parallel_mode)
     layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -312,7 +308,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     layer.bias.data.copy_(bias)
     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -336,7 +332,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -455,7 +451,6 @@ def check_vocab_parallel_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -466,10 +461,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
     k = global_context.get_local_rank(output_parallel_mode)
     embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    embed = embed.to(dtype).to(device)
+    embed = embed.to(device)
     embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    embed_master = embed_master.to(dtype).to(device)
+    embed_master = embed_master.to(device)
     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -479,10 +474,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
     embed.weight.data.copy_(weight)
     layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
     A_shape = (BATCH_SIZE, SEQ_LENGTH)
     A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
@@ -504,7 +499,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -546,12 +541,12 @@ def check_patch_embed():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)
-    layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
+    layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
     torch.nn.init.ones_(layer.cls_token)
     torch.nn.init.ones_(layer.pos_embed)
     layer = layer.to(device)
-    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
+    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
     torch.nn.init.ones_(layer_master.cls_token)
     torch.nn.init.ones_(layer_master.pos_embed)
     layer_master = layer_master.to(device)
@@ -566,7 +561,7 @@ def check_patch_embed():
     layer.bias.data.copy_(proj_bias)
     A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = A_master.clone()
@@ -586,7 +581,7 @@ def check_patch_embed():
     logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -639,9 +634,9 @@ def check_embed():
     k = global_context.get_local_rank(output_parallel_mode)
     layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -669,7 +664,7 @@ def check_embed():
     logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -686,10 +681,7 @@ def check_embed():
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    if j == k:
-        logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
-    else:
-        logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
+    logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
     return fwd_end - fwd_start, bwd_end - bwd_start
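
The deleted else branch accepted layer.weight.grad is None on ranks off the j == k diagonal; the test now requires a comparable gradient shard on every rank, which presumes the updated Embedding3D populates weight.grad everywhere. For reference, a grad of None is simply what autograd leaves on a parameter that never joined the graph:

    import torch

    p = torch.nn.Parameter(torch.ones(2))
    q = torch.nn.Parameter(torch.ones(2))
    (p * 2).sum().backward()     # only p participates in the graph
    print(p.grad)                # tensor([2., 2.])
    print(q.grad)                # None, which is what the old else branch checked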
@@ -709,9 +701,9 @@ def check_vocab_parallel_embed():
     k = global_context.get_local_rank(output_parallel_mode)
     layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)
     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -741,7 +733,7 @@ def check_vocab_parallel_embed():
     logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -771,7 +763,6 @@ def check_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -783,8 +774,8 @@ def check_loss():
     criterion_master = torch.nn.CrossEntropyLoss()
     out_shape = (BATCH_SIZE, NUM_CLASSES)
-    out_master = torch.randn(out_shape, dtype=dtype, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
+    out_master = torch.randn(out_shape, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -836,8 +827,8 @@ def check_vocab_parallel_loss():
     criterion_master = torch.nn.CrossEntropyLoss()
     out_shape = (BATCH_SIZE, NUM_CLASSES)
-    out_master = torch.randn(out_shape, dtype=dtype, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
+    out_master = torch.randn(out_shape, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
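
Both loss checks keep dtype=torch.long on the targets even as the float dtype arguments are dropped: torch.nn.CrossEntropyLoss takes integer class indices, while the logits stay in the default float32. A self-contained sketch with assumed sizes:

    import torch

    BATCH_SIZE, NUM_CLASSES = 4, 10           # assumed, for illustration only
    out = torch.randn(BATCH_SIZE, NUM_CLASSES)
    target = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long)
    print(torch.nn.CrossEntropyLoss()(out, target).item())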

View File

@@ -12,8 +12,8 @@ NUM_BLOCKS = 2
 IMG_SIZE = 16
 VOCAB_SIZE = 16
 def check_equal(A, B):
     eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
-    assert eq
-    return eq
+    assert eq, f"\nA = {A}\nB = {B}"
+    return eq
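
With the f-string message, a failing comparison now raises AssertionError showing both tensors instead of a bare assert. Run on its own, the updated helper behaves like this:

    import torch

    def check_equal(A, B):
        eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
        assert eq, f"\nA = {A}\nB = {B}"
        return eq

    print(check_equal(torch.ones(2), torch.ones(2) + 1e-3))  # True: inside tolerance
    check_equal(torch.ones(2), torch.zeros(2))               # AssertionError prints A and B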

View File

@@ -10,9 +10,8 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
-from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
-                                      check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
-                                      check_vocab_parallel_classifier_given_embed_weight,
+from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear,
+                                      check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight,
                                       check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
                                       check_vocab_parallel_loss)
@@ -30,7 +29,6 @@ def check_layer():
     check_layernorm()
     check_classifier_no_given_weight()
     check_vocab_parallel_classifier_no_given_weight()
-    check_classifier_given_embed_weight()
     check_vocab_parallel_classifier_given_embed_weight()
     check_embed()
     check_patch_embed()