moved env variables to global variables; (#215)

added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed few collective communicator bugs
This commit is contained in:
アマデウス
2022-02-14 11:15:02 +08:00
committed by Frank Lee
parent b82d60be02
commit 9ee197d0e9
63 changed files with 4304 additions and 1040 deletions

View File

@@ -3,16 +3,17 @@
import time
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D)
import torch
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context
from colossalai.logging import get_dist_logger
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VanillaClassifier,
VanillaPatchEmbedding)
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.utils import get_current_device, print_rank_0
from .common import *
import torch
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear():
@@ -27,9 +28,9 @@ def check_linear():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
layer = layer.to(device)
@@ -112,9 +113,9 @@ def check_layernorm():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
norm = norm.to(device)
@@ -186,7 +187,7 @@ def check_layernorm():
return fwd_end - fwd_start, bwd_end - bwd_start
def check_classifier():
def check_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
@@ -197,9 +198,9 @@ def check_classifier():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
layer = layer.to(device)
@@ -229,14 +230,14 @@ def check_classifier():
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
logger)
'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} head forward: {}'.format(rank, check_equal(out, C)))
logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -249,7 +250,7 @@ def check_classifier():
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
@@ -257,23 +258,275 @@ def check_classifier():
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} head backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} head backward (weight_grad): {}'.format(rank,
check_equal(B_grad, layer.weight.grad)))
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} head backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
rank, layer.weight.grad is None))
bias_grad = layer_master.bias.grad
logger.info('Rank {} head backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_embed():
def check_vocab_parallel_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(embed(A))
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, embed.weight.grad)))
else:
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
rank, embed.weight.grad is None))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(embed(A))
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
embed.weight.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_patch_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
@@ -283,9 +536,9 @@ def check_embed():
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
@@ -310,18 +563,99 @@ def check_embed():
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger)
'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank,
check_equal(pos_grad, layer.pos_embed.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank,
check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank,
check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start),
ranks=[0])
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
@@ -329,7 +663,7 @@ def check_embed():
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
@@ -339,30 +673,88 @@ def check_embed():
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(rank, check_equal(pos_grad, layer.pos_embed.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(dtype).to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
ranks=[0])
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
if j == k:
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, check_equal(B_grad,
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
layer.weight.grad)))
else:
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, layer.weight.grad is None))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
@@ -375,11 +767,9 @@ def check_loss():
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = A_rank = global_context.get_local_rank(input_parallel_mode)
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
k = C_rank = global_context.get_local_rank(output_parallel_mode)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
criterion = CrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss()
@@ -397,24 +787,79 @@ def check_loss():
fwd_start = time.time()
loss = criterion(out, target_master)
fwd_end = time.time()
print_rank_0(
'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
logger)
logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape),
fwd_end - fwd_start),
ranks=[0])
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} CrossEntropyLoss forward: {}'.format(rank, check_equal(loss, loss_master)))
logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time()
loss.backward()
bwd_end = time.time()
print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} CrossEntropyLoss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
criterion = VocabParallelCrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=-1)[k]
out = torch.chunk(out, DEPTH, dim=0)[j]
out = out.clone()
out.requires_grad = True
fwd_start = time.time()
loss = criterion(out, target_master)
fwd_end = time.time()
logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
ranks=[0])
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time()
loss.backward()
bwd_end = time.time()
logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start

View File

@@ -10,6 +10,7 @@ HIDDEN_SIZE = 8
NUM_CLASSES = 8
NUM_BLOCKS = 2
IMG_SIZE = 16
VOCAB_SIZE = 16
def check_equal(A, B):
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)

View File

@@ -7,9 +7,14 @@ import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port
from checks_3d.check_layer_3d import *
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
check_vocab_parallel_loss)
CONFIG = dict(
parallel=dict(
@@ -23,13 +28,23 @@ CONFIG = dict(
def check_layer():
check_linear()
check_layernorm()
check_classifier()
# check_embed()
# check_loss()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
check_layer()
gpc.destroy()
torch.cuda.empty_cache()