Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-13 13:11:05 +00:00)

Commit: updated tp layers
@@ -20,7 +20,6 @@ def check_linear():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE

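A side note on the removed `dtype = torch.float32` lines: once the explicit dtype is dropped, tensor factories such as torch.randn simply follow torch's global default dtype, which is float32 unless it has been changed. A minimal standalone check (plain PyTorch, no ColossalAI needed):

import torch

# torch.randn and friends use the global default dtype when none is passed.
print(torch.get_default_dtype())    # torch.float32 by default
print(torch.randn(2, 2).dtype)      # torch.float32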
@@ -32,12 +31,12 @@ def check_linear():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)

-    layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
+    layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
     layer = layer.to(device)
     layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
     layer_master = layer_master.to(device)

-    weight_master = layer_master.weight.data.transpose(0, 1)
+    weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
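The added .contiguous() is worth calling out: transpose returns a non-contiguous view, and the test now materializes it before broadcasting and chunking. A quick standalone illustration:

import torch

w = torch.nn.Linear(4, 8).weight.data    # shape (8, 4)
wt = w.transpose(0, 1)                   # view with shape (4, 8)
print(wt.is_contiguous())                # False: transpose only changes strides
print(wt.contiguous().is_contiguous())   # True: data is copied into a dense layout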
@@ -49,7 +48,7 @@ def check_linear():
     layer.bias.data.copy_(bias)

     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -72,7 +71,7 @@ def check_linear():
     logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -108,7 +107,6 @@ def check_layernorm():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE

     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -119,7 +117,7 @@ def check_layernorm():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)

-    norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
+    norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
     norm = norm.to(device)
     norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
     norm_master = norm_master.to(device)
@@ -134,7 +132,7 @@ def check_layernorm():
     norm.bias.data.copy_(bias)

     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -159,7 +157,7 @@ def check_layernorm():
     logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -193,7 +191,6 @@ def check_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE

     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -204,10 +201,10 @@ def check_classifier_no_given_weight():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)

-    layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
+    layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
     layer = layer.to(device)

-    layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True, dtype=dtype)
+    layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
     layer_master = layer_master.to(device)

     weight_master = layer_master.weight.data
@@ -219,7 +216,7 @@ def check_classifier_no_given_weight():
     layer.bias.data.copy_(bias_master)

     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -242,7 +239,7 @@ def check_classifier_no_given_weight():
     logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -283,7 +280,6 @@ def check_vocab_parallel_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE

     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -295,10 +291,10 @@ def check_vocab_parallel_classifier_no_given_weight():
     k = global_context.get_local_rank(output_parallel_mode)

     layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)

     layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)

     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -312,7 +308,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     layer.bias.data.copy_(bias)

     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[k]
@@ -336,7 +332,7 @@ def check_vocab_parallel_classifier_no_given_weight():
     logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -455,7 +451,6 @@ def check_vocab_parallel_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32

     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -466,10 +461,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
     k = global_context.get_local_rank(output_parallel_mode)

     embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    embed = embed.to(dtype).to(device)
+    embed = embed.to(device)

     embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    embed_master = embed_master.to(dtype).to(device)
+    embed_master = embed_master.to(device)

     weight_master = embed_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -479,10 +474,10 @@ def check_vocab_parallel_classifier_given_embed_weight():
     embed.weight.data.copy_(weight)

     layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)

     layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)

     A_shape = (BATCH_SIZE, SEQ_LENGTH)
     A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
@@ -504,7 +499,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
     logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
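The weight=embed.weight argument ties the classifier head to the embedding matrix. In plain PyTorch the same idea looks roughly like this (a standalone sketch with made-up sizes, not the ColossalAI implementation):

import torch

VOCAB_SIZE, HIDDEN_SIZE = 16, 8

embed = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
head = torch.nn.Linear(HIDDEN_SIZE, VOCAB_SIZE, bias=False)
head.weight = embed.weight                  # share the same parameter tensor

tokens = torch.randint(VOCAB_SIZE, (2, 4))
logits = head(embed(tokens))                # shape (2, 4, VOCAB_SIZE)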
@@ -546,12 +541,12 @@ def check_patch_embed():
     i = global_context.get_local_rank(weight_parallel_mode)
     k = global_context.get_local_rank(output_parallel_mode)

-    layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
+    layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
     torch.nn.init.ones_(layer.cls_token)
     torch.nn.init.ones_(layer.pos_embed)
     layer = layer.to(device)

-    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
+    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
     torch.nn.init.ones_(layer_master.cls_token)
     torch.nn.init.ones_(layer_master.pos_embed)
     layer_master = layer_master.to(device)
@@ -566,7 +561,7 @@ def check_patch_embed():
     layer.bias.data.copy_(proj_bias)

     A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A_master = torch.randn(A_shape, device=device)
     torch.distributed.broadcast(A_master, src=0)
     A = A_master.clone()

@@ -586,7 +581,7 @@ def check_patch_embed():
     logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -639,9 +634,9 @@ def check_embed():
     k = global_context.get_local_rank(output_parallel_mode)

     layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)

     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -669,7 +664,7 @@ def check_embed():
     logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -686,10 +681,7 @@ def check_embed():

     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    if j == k:
-        logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
-    else:
-        logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
+    logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))

     return fwd_end - fwd_start, bwd_end - bwd_start

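After this change, every rank compares its slice of the master weight gradient instead of expecting None off the diagonal. The slicing itself is ordinary chunking along the hidden dimension; a standalone sketch with made-up sizes:

import torch

DEPTH = 2                                      # assumed cube depth
VOCAB_SIZE, HIDDEN_SIZE = 16, 8

B_grad = torch.randn(VOCAB_SIZE, HIDDEN_SIZE)  # stand-in for layer_master.weight.grad
k = 1                                          # hypothetical local rank on the output axis
local_slice = torch.chunk(B_grad, DEPTH, dim=-1)[k]
print(local_slice.shape)                       # torch.Size([16, 4])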
@@ -709,9 +701,9 @@ def check_vocab_parallel_embed():
     k = global_context.get_local_rank(output_parallel_mode)

     layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
-    layer = layer.to(dtype).to(device)
+    layer = layer.to(device)
     layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
-    layer_master = layer_master.to(dtype).to(device)
+    layer_master = layer_master.to(device)

     weight_master = layer_master.weight.data
     torch.distributed.broadcast(weight_master, src=0)
@@ -741,7 +733,7 @@ def check_vocab_parallel_embed():
     logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))

     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
+    grad_master = torch.randn(grad_shape, device=device)
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -771,7 +763,6 @@ def check_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32

     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -783,8 +774,8 @@ def check_loss():
     criterion_master = torch.nn.CrossEntropyLoss()

     out_shape = (BATCH_SIZE, NUM_CLASSES)
-    out_master = torch.randn(out_shape, dtype=dtype, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
+    out_master = torch.randn(out_shape, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
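For orientation, splitting the broadcast logits along the batch dimension and evaluating cross-entropy on the local shard looks like the following in plain PyTorch (an illustrative single-process sketch with made-up sizes, not the test itself):

import torch

DEPTH = 2
BATCH_SIZE, NUM_CLASSES = 8, 10

out_master = torch.randn(BATCH_SIZE, NUM_CLASSES)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long)

i = 0                                               # hypothetical local rank on the batch axis
out = torch.chunk(out_master, DEPTH, dim=0)[i]
target = torch.chunk(target_master, DEPTH, dim=0)[i]
loss = torch.nn.CrossEntropyLoss()(out, target)     # loss on the local batch shard
print(loss.item())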
@@ -836,8 +827,8 @@ def check_vocab_parallel_loss():
     criterion_master = torch.nn.CrossEntropyLoss()

     out_shape = (BATCH_SIZE, NUM_CLASSES)
-    out_master = torch.randn(out_shape, dtype=dtype, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
+    out_master = torch.randn(out_shape, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -12,8 +12,8 @@ NUM_BLOCKS = 2
 IMG_SIZE = 16
 VOCAB_SIZE = 16


 def check_equal(A, B):
     eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
-    assert eq
+    assert eq, f"\nA = {A}\nB = {B}"
     return eq
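With the new assertion message, a failing comparison prints both tensors instead of a bare AssertionError. A quick standalone demonstration of the helper:

import torch

def check_equal(A, B):
    eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
    assert eq, f"\nA = {A}\nB = {B}"
    return eq

check_equal(torch.ones(3), torch.ones(3) + 1e-4)   # passes: difference is within rtol/atol
# check_equal(torch.ones(3), torch.zeros(3))       # would raise and print both tensors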
@@ -10,9 +10,8 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
-from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
-                                      check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
-                                      check_vocab_parallel_classifier_given_embed_weight,
+from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear,
+                                      check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight,
                                       check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
                                       check_vocab_parallel_loss)

@@ -30,7 +29,6 @@ def check_layer():
     check_layernorm()
     check_classifier_no_given_weight()
     check_vocab_parallel_classifier_no_given_weight()
-    check_classifier_given_embed_weight()
     check_vocab_parallel_classifier_given_embed_weight()
     check_embed()
     check_patch_embed()