Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
moved env variables to global variables (#215)
added branch context; added vocab parallel layers; moved split_batch from load_batch to tensor parallel embedding layers; updated gpt model; updated unit test cases; fixed a few collective communicator bugs
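
The vocab parallel layers introduced by this commit shard the embedding table along the vocabulary dimension (dim 0) instead of the hidden dimension, which is why the tests below chunk the master weight with dim=0 before copying it into VocabParallelEmbedding1D/2D/2p5D. The following is a minimal single-process sketch of that idea, not the ColossalAI implementation; the sizes mirror the toy constants used in the tests, and the summation over "ranks" stands in for the all-reduce a real tensor-parallel group would perform.

# Illustrative sketch (assumed toy sizes; not the library's actual kernels).
import torch

VOCAB_SIZE, HIDDEN_SIZE, DEPTH = 16, 8, 4
BATCH_SIZE, SEQ_LENGTH = 2, 8

master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
tokens = torch.randint(VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))

# Split the vocabulary dimension across DEPTH simulated ranks.
shards = torch.chunk(master.weight.data, DEPTH, dim=0)
part = VOCAB_SIZE // DEPTH
out = torch.zeros(BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)

for rank, shard in enumerate(shards):
    lo, hi = rank * part, (rank + 1) * part
    mask = (tokens >= lo) & (tokens < hi)          # tokens owned by this rank
    local_ids = (tokens - lo).clamp(0, part - 1)   # shift into the local shard's index range
    local_out = torch.nn.functional.embedding(local_ids, shard)
    out += local_out * mask.unsqueeze(-1)          # zero contributions for tokens owned elsewhere

# The summed partial outputs match the unsharded embedding.
assert torch.allclose(out, master(tokens))
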
@@ -1,12 +1,14 @@
import torch
import torch.distributed as dist
from torch.nn import Parameter
import time
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import Linear1D_Col, Linear1D_Row
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import (Classifier1D, Embedding1D, Linear1D_Col, Linear1D_Row, VanillaClassifier,
                           VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D)
from colossalai.utils import get_current_device, print_rank_0
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES, check_equal, IMG_SIZE
from torch.nn import Parameter

from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal


def check_linear_col():
@@ -144,3 +146,351 @@ def check_linear_row():
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear_row backward: pass')

def check_embed():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('embed backward: pass')

def check_vocab_parallel_embed():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('vocab parallel embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('vocab parallel embed backward: pass')

def check_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    env.parallel_input_1d = False
    parallel_input_1d = env.parallel_input_1d
    layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    W_master = layer_master.weight.data
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    layer.weight.data.copy_(W)
    B_master = layer_master.bias.data
    dist.broadcast(B_master, src=0)
    B = B_master.clone()
    layer.bias.data.copy_(B)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    if parallel_input_1d:
        A = torch.chunk(A_master, DEPTH, dim=-1)[i]
        A = A.clone()
    else:
        A = A_master.clone()
    A.requires_grad = True

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    if parallel_input_1d:
        A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('classifier (no given weight) backward: pass')

def check_vocab_parallel_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    W_master = layer_master.weight.data
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    layer.weight.data.copy_(W)
    B_master = layer_master.bias.data
    dist.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    layer.bias.data.copy_(B)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]

    check_equal(out, C)
    print_rank_0('vocab parallel classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('vocab parallel classifier (no given weight) backward: pass')

def check_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    env.parallel_input_1d = False
    layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, embed.weight.grad)

    print_rank_0('classifier (given embed weight) backward: pass')

def check_vocab_parallel_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    env.parallel_input_1d = False
    layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]
    check_equal(out, C)
    print_rank_0('vocab parallel classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, embed.weight.grad)

    print_rank_0('vocab parallel classifier (given embed weight) backward: pass')

def check_vocab_parallel_loss():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    criterion = VocabParallelCrossEntropyLoss1D()
    criterion_master = torch.nn.CrossEntropyLoss()

    out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)
    out_master = torch.randn(out_shape, dtype=dtype, device=device)
    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device)
    torch.distributed.broadcast(out_master, src=0)
    torch.distributed.broadcast(target_master, src=0)
    out = torch.chunk(out_master, DEPTH, dim=-1)[i]
    out = out.clone()
    out.requires_grad = True

    loss = criterion(out, target_master)

    out_master = out_master.clone()
    out_master.requires_grad = True
    loss_master = criterion_master(out_master, target_master)
    check_equal(loss, loss_master)
    print_rank_0('vocab parallel loss forward: pass')

    loss.backward()
    loss_master.backward()

    out_grad = out_master.grad
    out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i]
    check_equal(out_grad, out.grad)
    print_rank_0('vocab parallel loss backward: pass')

@@ -9,6 +9,7 @@ SEQ_LENGTH = 8
IMG_SIZE = 16
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16

def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True

@@ -7,6 +7,7 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
from colossalai.initialize import launch
from colossalai.utils import free_port

@@ -24,6 +25,7 @@ CONFIG = dict(


def check_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG,
           rank=rank,
           world_size=world_size,
@@ -33,6 +35,13 @@ def check_layer(rank, world_size, port):

    check_linear_col()
    check_linear_row()
    check_embed()
    check_vocab_parallel_embed()
    check_classifier_no_given_weight()
    check_vocab_parallel_classifier_no_given_weight()
    check_classifier_given_embed_weight()
    check_vocab_parallel_classifier_given_embed_weight()
    check_vocab_parallel_loss()

    gpc.destroy()
    torch.cuda.empty_cache()

@@ -1,11 +1,12 @@
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import Linear2D, LayerNorm2D, Classifier2D
|
||||
from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D,
|
||||
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D,
|
||||
VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D)
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
from .common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal, NUM_CLASSES
|
||||
|
||||
from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal)
|
||||
|
||||
|
||||
def check_linear():
|
||||
@@ -57,7 +58,6 @@ def check_linear():
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
|
||||
# print(f'Rank {gpc.get_global_rank()} A:\n{A}\nRank {gpc.get_global_rank()} W:\n{W}\nRank {gpc.get_global_rank()} b:\n{B}\nRank {gpc.get_global_rank()} C:\n{C}\nRank {gpc.get_global_rank()} out:\n{out}')
|
||||
check_equal(out, C)
|
||||
print_rank_0('linear forward: pass')
|
||||
|
||||
@@ -90,84 +90,6 @@ def check_linear():
|
||||
print_rank_0('linear backward: pass')
|
||||
|
||||
|
||||
def check_classifier():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = NUM_CLASSES
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
|
||||
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(W_master, src=0)
|
||||
W = torch.chunk(W_master, DEPTH, dim=-1)[j]
|
||||
W = torch.chunk(W, DEPTH, dim=-1)[i]
|
||||
W = W.clone()
|
||||
layer.weight.data.copy_(W)
|
||||
# W.requires_grad = True
|
||||
|
||||
B_shape = (OUTPUT_SIZE,)
|
||||
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
|
||||
B = B_master.clone()
|
||||
layer.bias.data.copy_(B)
|
||||
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
W_master = W_master.clone()
|
||||
W_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
# C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
# if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('classifier backward: pass')
|
||||
|
||||
|
||||
def check_layernorm():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
@@ -219,6 +141,497 @@ def check_layernorm():
|
||||
print_rank_0('layer norm backward: pass')
|
||||
|
||||
|
||||
def check_embed():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = embed(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = embed_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(B_grad, embed.weight.grad)
|
||||
print_rank_0('embed backward: pass')
|
||||
|
||||
|
||||
def check_patch_embed():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer.cls_token)
|
||||
torch.nn.init.ones_(layer.pos_embed)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer_master.cls_token)
|
||||
torch.nn.init.ones_(layer_master.pos_embed)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
proj_weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(proj_weight_master, src=0)
|
||||
proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j]
|
||||
proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i]
|
||||
layer.weight.data.copy_(proj_weight)
|
||||
proj_bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(proj_bias_master, src=0)
|
||||
proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j]
|
||||
proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i]
|
||||
layer.bias.data.copy_(proj_bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('patch embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j]
|
||||
cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(cls_grad, layer.cls_token.grad)
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j]
|
||||
pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(pos_grad, layer.pos_embed.grad)
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
check_equal(B_grad, layer.weight.grad)
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
check_equal(bias_grad, layer.bias.grad)
|
||||
print_rank_0('patch embed backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_embed():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = embed(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = embed_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
check_equal(B_grad, embed.weight.grad)
|
||||
print_rank_0('vocab parallel embed backward: pass')
|
||||
|
||||
|
||||
def check_classifier_no_given_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = NUM_CLASSES
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
|
||||
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(W_master, src=0)
|
||||
W = torch.chunk(W_master, DEPTH, dim=-1)[j]
|
||||
W = torch.chunk(W, DEPTH, dim=-1)[i]
|
||||
W = W.clone()
|
||||
layer.weight.data.copy_(W)
|
||||
# W.requires_grad = True
|
||||
|
||||
B_shape = (OUTPUT_SIZE, )
|
||||
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
|
||||
B = B_master.clone()
|
||||
layer.bias.data.copy_(B)
|
||||
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
W_master = W_master.clone()
|
||||
W_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
# C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier (no given weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
# if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('classifier (no given weight) backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_no_given_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[j]
|
||||
bias = torch.chunk(bias, DEPTH)[i]
|
||||
layer.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = layer_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = layer_master.bias.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH)[i]
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
|
||||
|
||||
|
||||
def check_classifier_given_embed_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(embed(A))
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier (given embed weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = embed_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
|
||||
check_equal(W_grad, embed.weight.grad)
|
||||
print_rank_0('classifier (given embed weight) backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_given_embed_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(embed(A))
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = embed_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
|
||||
check_equal(W_grad, embed.weight.grad)
|
||||
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
|
||||
|
||||
|
||||
def check_loss():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
criterion = CrossEntropyLoss2D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
loss = criterion(out, target_master)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
check_equal(loss, loss_master)
|
||||
print_rank_0('cross entropy loss forward: pass')
|
||||
|
||||
loss.backward()
|
||||
loss_master.backward()
|
||||
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
check_equal(out_grad, out.grad)
|
||||
print_rank_0('cross entropy loss backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_loss():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
criterion = VocabParallelCrossEntropyLoss2D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
out = torch.chunk(out, DEPTH, dim=-1)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
loss = criterion(out, target_master)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
check_equal(loss, loss_master)
|
||||
print_rank_0('vocab parallel cross entropy loss forward: pass')
|
||||
|
||||
loss.backward()
|
||||
loss_master.backward()
|
||||
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j]
|
||||
check_equal(out_grad, out.grad)
|
||||
print_rank_0('vocab parallel cross entropy loss backward: pass')
|
||||
|
||||
|
||||
# def check_attention():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
@@ -257,7 +670,6 @@ def check_layernorm():
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('self attention backward: pass')
|
||||
|
||||
|
||||
# def check_mlp():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
@@ -291,7 +703,6 @@ def check_layernorm():
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('mlp backward: pass')
|
||||
|
||||
|
||||
# def check_transformerlayer():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
|
@@ -8,6 +8,9 @@ BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) == True
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)

@@ -8,20 +8,17 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_2d.check_layer_2d import *
|
||||
from checks_2d.check_operation_2d import *
|
||||
from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
|
||||
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
|
||||
check_vocab_parallel_classifier_given_embed_weight,
|
||||
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
|
||||
check_vocab_parallel_loss)
|
||||
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
pipeline=dict(size=1),
|
||||
tensor=dict(
|
||||
size=4,
|
||||
mode='2d'
|
||||
)
|
||||
),
|
||||
)
|
||||
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')), )
|
||||
|
||||
|
||||
def check_operations():
|
||||
@@ -33,16 +30,24 @@ def check_operations():
|
||||
def check_layer():
|
||||
check_linear()
|
||||
check_layernorm()
|
||||
check_classifier()
|
||||
check_embed()
|
||||
check_patch_embed()
|
||||
check_vocab_parallel_embed()
|
||||
check_classifier_no_given_weight()
|
||||
check_vocab_parallel_classifier_no_given_weight()
|
||||
check_classifier_given_embed_weight()
|
||||
check_vocab_parallel_classifier_given_embed_weight()
|
||||
check_loss()
|
||||
check_vocab_parallel_loss()
|
||||
|
||||
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
disable_existing_loggers()
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
# check_operations()
|
||||
check_layer()
|
||||
gpc.destroy()
|
||||
|
@@ -1,11 +1,12 @@
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn import Linear2p5D, LayerNorm2p5D, Classifier2p5D
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import print_rank_0
|
||||
from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D,
|
||||
PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D,
|
||||
VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D)
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
from torch.nn import Parameter
|
||||
|
||||
from .common import *
|
||||
|
||||
|
||||
@@ -19,11 +20,7 @@ def check_linear():
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layer = Linear2p5D(
|
||||
INPUT_SIZE,
|
||||
OUTPUT_SIZE,
|
||||
dtype=dtype,
|
||||
skip_bias_add=False)
|
||||
layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
@@ -94,86 +91,6 @@ def check_linear():
|
||||
print_rank_0('linear backward: pass')
|
||||
|
||||
|
||||
def check_classifier():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = NUM_CLASSES
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
|
||||
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(W_master, src=0)
|
||||
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
|
||||
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
|
||||
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
|
||||
W = W.clone()
|
||||
layer.weight.data.copy_(W)
|
||||
# W.requires_grad = True
|
||||
|
||||
B_shape = (OUTPUT_SIZE,)
|
||||
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
|
||||
B = B_master.clone()
|
||||
layer.bias.data.copy_(B)
|
||||
|
||||
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
W_master = W_master.clone()
|
||||
W_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
|
||||
# if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('classifier backward: pass')
|
||||
|
||||
|
||||
def check_layernorm():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
@@ -184,9 +101,7 @@ def check_layernorm():
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layernorm = LayerNorm2p5D(
|
||||
INPUT_SIZE,
|
||||
dtype=dtype)
|
||||
layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
@@ -228,6 +143,500 @@ def check_layernorm():
|
||||
print_rank_0('layer norm backward: pass')
|
||||
|
||||
|
||||
def check_embed():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = embed(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = embed_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(B_grad, embed.weight.grad)
|
||||
print_rank_0('embed backward: pass')
|
||||
|
||||
|
||||
def check_patch_embed():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer.cls_token)
|
||||
torch.nn.init.ones_(layer.pos_embed)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer_master.cls_token)
|
||||
torch.nn.init.ones_(layer_master.pos_embed)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
proj_weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(proj_weight_master, src=0)
|
||||
proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j]
|
||||
proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i]
|
||||
layer.weight.data.copy_(proj_weight)
|
||||
proj_bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(proj_bias_master, src=0)
|
||||
proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j]
|
||||
proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i]
|
||||
layer.bias.data.copy_(proj_bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('patch embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j]
|
||||
cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(cls_grad, layer.cls_token.grad)
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j]
|
||||
pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(pos_grad, layer.pos_embed.grad)
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(B_grad, layer.weight.grad)
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j]
|
||||
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i]
|
||||
check_equal(bias_grad, layer.bias.grad)
|
||||
print_rank_0('patch embed backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_embed():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = embed(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = embed_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(B_grad, embed.weight.grad)
|
||||
print_rank_0('vocab parallel embed backward: pass')
|
||||
|
||||
|
||||
def check_classifier_no_given_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = NUM_CLASSES
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
|
||||
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(W_master, src=0)
|
||||
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
|
||||
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
|
||||
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
|
||||
W = W.clone()
|
||||
layer.weight.data.copy_(W)
|
||||
# W.requires_grad = True
|
||||
|
||||
B_shape = (OUTPUT_SIZE, )
|
||||
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
|
||||
B = B_master.clone()
|
||||
layer.bias.data.copy_(B)
|
||||
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
W_master = W_master.clone()
|
||||
W_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier (no given weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
|
||||
# if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('classifier (no given weight) backward: pass')
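For reference, a plain-PyTorch sketch (illustrative sizes only) of why the check above compares the local output against a row chunk of the full output: with replicated classes, splitting only the batch rows of A yields exactly the corresponding rows of A @ W.T + b.

import torch
A = torch.randn(8, 4, 6)             # hypothetical (batch, seq, hidden)
W = torch.randn(3, 6)                # (num_classes, hidden)
b = torch.randn(3)
full = A @ W.t() + b
i = 1                                # batch-parallel rank, assuming a 2-way split
local = torch.chunk(A, 2, dim=0)[i] @ W.t() + b
assert torch.allclose(local, torch.chunk(full, 2, dim=0)[i], atol=1e-6)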
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_no_given_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, TESSERACT_DIM)[j]
|
||||
layer.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = layer_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = layer_master.bias.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j]
|
||||
if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
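A single-process sketch of the vocab-parallel split used above (illustrative sizes, 2-way vocab split): weight and bias are sharded along the class/vocab dimension, so the local output is a column chunk of the full output. The vocab shard of the bias is replicated along the other tesseract axis, which is presumably why its gradient is only compared on ranks with i == 0.

import torch
x = torch.randn(4, 6)                # (tokens, hidden)
W = torch.randn(16, 6)               # (vocab, hidden)
b = torch.randn(16)
full = x @ W.t() + b
j = 0                                # vocab-parallel rank
W_shard = torch.chunk(W, 2, dim=0)[j]
b_shard = torch.chunk(b, 2)[j]
local = x @ W_shard.t() + b_shard
assert torch.allclose(local, torch.chunk(full, 2, dim=-1)[j], atol=1e-6)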
|
||||
|
||||
|
||||
def check_classifier_given_embed_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(embed(A))
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier (given embed weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = embed_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(W_grad, embed.weight.grad)
|
||||
print_rank_0('classifier (given embed weight) backward: pass')
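The weight-tying pattern this check exercises, in plain PyTorch (illustrative only): the classifier reuses the embedding table, so the shared parameter collects gradient from both the lookup and the output projection.

import torch
emb = torch.nn.Embedding(16, 8)
ids = torch.randint(16, (2, 4))
logits = torch.nn.functional.linear(emb(ids), emb.weight)   # projection tied to the embedding weight
logits.sum().backward()
assert emb.weight.grad is not None and emb.weight.grad.shape == emb.weight.shape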
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_given_embed_weight():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(embed(A))
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = embed_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(W_grad, embed.weight.grad)
|
||||
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
|
||||
|
||||
|
||||
def check_loss():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
criterion = CrossEntropyLoss2p5D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
loss = criterion(out, target_master)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
check_equal(loss, loss_master)
|
||||
print_rank_0('cross entropy loss forward: pass')
|
||||
|
||||
loss.backward()
|
||||
loss_master.backward()
|
||||
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(out_grad, out.grad)
|
||||
print_rank_0('cross entropy loss backward: pass')
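CrossEntropyLoss2p5D is compared against the unsharded torch loss above; with equal-size batch shards the global mean loss is simply the average of the per-shard means. A minimal sketch (assuming a 2-way batch split):

import torch
logits = torch.randn(8, 4)
target = torch.randint(4, (8,))
full = torch.nn.functional.cross_entropy(logits, target)
shard_losses = [torch.nn.functional.cross_entropy(lo, t)
                for lo, t in zip(torch.chunk(logits, 2, dim=0), torch.chunk(target, 2, dim=0))]
assert torch.allclose(full, sum(shard_losses) / 2, atol=1e-6)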
|
||||
|
||||
|
||||
def check_vocab_parallel_loss():
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
criterion = VocabParallelCrossEntropyLoss2p5D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
|
||||
out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
loss = criterion(out, target_master)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
check_equal(loss, loss_master)
|
||||
print_rank_0('vocab parallel cross entropy loss forward: pass')
|
||||
|
||||
loss.backward()
|
||||
loss_master.backward()
|
||||
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out_grad, out.grad)
|
||||
print_rank_0('vocab parallel cross entropy loss backward: pass')
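For reference, a single-process sketch of the reduction any vocab-parallel cross entropy has to perform: the per-token max and the sum of exponentials are combined across vocab shards (emulated here with chunks; in a distributed layer these combinations would be collectives over the vocab-parallel group).

import torch
logits = torch.randn(4, 8)
target = torch.randint(8, (4,))
shards = torch.chunk(logits, 2, dim=-1)                                 # two vocab shards
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values
sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shards)
target_logit = logits.gather(1, target.unsqueeze(1)).squeeze(1)         # owned by exactly one shard
loss = (torch.log(sum_exp) + global_max - target_logit).mean()
assert torch.allclose(loss, torch.nn.functional.cross_entropy(logits, target), atol=1e-5)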
|
||||
|
||||
|
||||
# def check_attention():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
@@ -267,7 +676,6 @@ def check_layernorm():
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('self attention backward: pass')
|
||||
|
||||
|
||||
# def check_mlp():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
@@ -304,7 +712,6 @@ def check_layernorm():
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('mlp backward: pass')
|
||||
|
||||
|
||||
# def check_transformerlayer():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
@@ -344,4 +751,4 @@ def check_layernorm():
|
||||
|
||||
# out.backward(grad)
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('transformerlayer backward: pass')
|
||||
# print_rank_0('transformerlayer backward: pass')
|
||||
|
@@ -5,8 +5,10 @@ TESSERACT_DEP = 2
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LENGTH = 8
|
||||
HIDDEN_SIZE = 8
|
||||
NUM_CLASSES = 3
|
||||
NUM_CLASSES = 8
|
||||
VOCAB_SIZE = 16
|
||||
IMG_SIZE = 16
|
||||
|
||||
|
||||
def check_equal(A, B):
|
||||
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True
|
||||
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
|
@@ -5,10 +5,10 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_2p5d.check_layer_2p5d import (check_classifier, check_layernorm,
|
||||
check_linear)
|
||||
from checks_2p5d.check_layer_2p5d import *
|
||||
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
|
||||
|
||||
CONFIG = dict(
|
||||
@@ -28,10 +28,19 @@ def check_operations():
|
||||
def check_layer():
|
||||
check_linear()
|
||||
check_layernorm()
|
||||
check_classifier()
|
||||
check_embed()
|
||||
check_patch_embed()
|
||||
check_vocab_parallel_embed()
|
||||
check_classifier_no_given_weight()
|
||||
check_vocab_parallel_classifier_no_given_weight()
|
||||
check_classifier_given_embed_weight()
|
||||
check_vocab_parallel_classifier_given_embed_weight()
|
||||
check_loss()
|
||||
check_vocab_parallel_loss()
|
||||
|
||||
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
@@ -39,6 +48,9 @@ def check_layer_and_operation(rank, world_size, port):
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
check_operations()
|
||||
check_layer()
|
||||
gpc.destroy()
|
||||
|
@@ -3,16 +3,17 @@
|
||||
|
||||
import time
|
||||
|
||||
from colossalai.constants import (INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D)
|
||||
import torch
|
||||
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.core import global_context
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VanillaClassifier,
|
||||
VanillaPatchEmbedding)
|
||||
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
|
||||
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
|
||||
VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
|
||||
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
from .common import *
|
||||
import torch
|
||||
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
|
||||
|
||||
|
||||
def check_linear():
|
||||
@@ -27,9 +28,9 @@ def check_linear():
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = A_rank = global_context.get_local_rank(input_parallel_mode)
|
||||
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = C_rank = global_context.get_local_rank(output_parallel_mode)
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, bias=True)
|
||||
layer = layer.to(device)
|
||||
@@ -112,9 +113,9 @@ def check_layernorm():
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = A_rank = global_context.get_local_rank(input_parallel_mode)
|
||||
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = C_rank = global_context.get_local_rank(output_parallel_mode)
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6, dtype=dtype)
|
||||
norm = norm.to(device)
|
||||
@@ -186,7 +187,7 @@ def check_layernorm():
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_classifier():
|
||||
def check_classifier_no_given_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
@@ -197,9 +198,9 @@ def check_classifier():
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = A_rank = global_context.get_local_rank(input_parallel_mode)
|
||||
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = C_rank = global_context.get_local_rank(output_parallel_mode)
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, dtype=dtype, bias=True)
|
||||
layer = layer.to(device)
|
||||
@@ -229,14 +230,14 @@ def check_classifier():
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
|
||||
logger)
|
||||
'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} head forward: {}'.format(rank, check_equal(out, C)))
|
||||
logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
@@ -249,7 +250,7 @@ def check_classifier():
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
@@ -257,23 +258,275 @@ def check_classifier():
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} head backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
|
||||
logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
if j == k:
|
||||
logger.info('Rank {} head backward (weight_grad): {}'.format(rank,
|
||||
check_equal(B_grad, layer.weight.grad)))
|
||||
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, layer.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} head backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
|
||||
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
|
||||
rank, layer.weight.grad is None))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
logger.info('Rank {} head backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
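The 3D checks above slice the reference output twice along dim 0 ([i] then [j]); nested chunking is equivalent to a single flat split indexed by i * DEPTH + j. A small sketch with DEPTH = 2 and hypothetical ranks:

import torch
C = torch.randn(8, 5)
i, j = 1, 0
part = torch.chunk(torch.chunk(C, 2, dim=0)[i], 2, dim=0)[j]
assert torch.equal(part, torch.chunk(C, 4, dim=0)[i * 2 + j])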
|
||||
|
||||
|
||||
def check_embed():
|
||||
def check_vocab_parallel_classifier_no_given_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[j]
|
||||
layer.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
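The forward/backward timing pattern used throughout these checks, factored into a small helper for clarity (a sketch only; the tests inline it). Synchronizing before reading the clock matters because CUDA kernels run asynchronously.

import time
import torch

def timed(fn, *args):
    if torch.cuda.is_available():
        torch.cuda.synchronize()      # make sure pending kernels are done before starting the timer
    start = time.time()
    out = fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()      # wait for the measured kernels before stopping the timer
    return out, time.time() - start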
|
||||
|
||||
|
||||
def check_classifier_given_embed_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(embed(A))
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
if j == k:
|
||||
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, embed.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
|
||||
rank, embed.weight.grad is None))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_given_embed_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(embed(A))
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (given embed weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, embed.weight.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_patch_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_dist_logger()
|
||||
@@ -283,9 +536,9 @@ def check_embed():
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = A_rank = global_context.get_local_rank(input_parallel_mode)
|
||||
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = C_rank = global_context.get_local_rank(output_parallel_mode)
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer.cls_token)
|
||||
@@ -310,18 +563,99 @@ def check_embed():
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
|
||||
fwd_end - fwd_start), logger)
|
||||
'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
|
||||
fwd_end - fwd_start), logger)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank,
|
||||
check_equal(pos_grad, layer.pos_embed.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank,
|
||||
check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
|
||||
logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank,
|
||||
check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
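PatchEmbedding3D is checked against VanillaPatchEmbedding above. At its core a patch embedding is a strided projection, equivalent to a linear layer applied to flattened patches; a minimal sketch with a hypothetical 4x4 patch:

import torch
x = torch.randn(2, 3, 16, 16)                                       # (B, C, IMG_SIZE, IMG_SIZE)
proj = torch.nn.Conv2d(3, 8, kernel_size=4, stride=4)
out_conv = proj(x).flatten(2).transpose(1, 2)                       # (B, 16 patches, hidden)
cols = torch.nn.functional.unfold(x, kernel_size=4, stride=4)       # (B, 3*4*4, 16 patches)
out_lin = cols.transpose(1, 2) @ proj.weight.reshape(8, -1).t() + proj.bias
assert torch.allclose(out_conv, out_lin, atol=1e-5)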
|
||||
|
||||
|
||||
def check_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_dist_logger()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
|
||||
fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
@@ -329,7 +663,7 @@ def check_embed():
|
||||
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
@@ -339,30 +673,88 @@ def check_embed():
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
if j == k:
|
||||
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, layer.weight.grad is None))
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(rank, check_equal(pos_grad, layer.pos_embed.grad)))
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
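Embedding3D only splits the hidden dimension, so the local lookup is a hidden-dim chunk of the full lookup. A single-process sketch with small illustrative sizes:

import torch
emb = torch.nn.Embedding(16, 8)                                  # (VOCAB_SIZE, HIDDEN_SIZE)
ids = torch.randint(16, (2, 4))
full = emb(ids)
k = 0                                                            # output-parallel rank, assuming DEPTH = 2
w_shard = torch.chunk(emb.weight, 2, dim=-1)[k]
local = torch.nn.functional.embedding(ids, w_shard)
assert torch.equal(local, torch.chunk(full, 2, dim=-1)[k])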
|
||||
|
||||
|
||||
def check_vocab_parallel_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_dist_logger()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
if j == k:
|
||||
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, check_equal(B_grad,
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
|
||||
check_equal(B_grad,
|
||||
layer.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(rank, layer.weight.grad is None))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
|
||||
logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
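VocabParallelEmbedding3D splits the vocabulary dimension instead: each rank looks up only the ids it owns, zeroes the rest, and the partial results are summed across the vocab-parallel group. A single-process emulation (chunks and a plain sum in place of the collective; sizes are illustrative):

import torch
vocab, hidden, shards = 16, 8, 2
w = torch.randn(vocab, hidden)
ids = torch.randint(vocab, (2, 4))
full = torch.nn.functional.embedding(ids, w)
per = vocab // shards
partials = []
for r in range(shards):
    lo, hi = r * per, (r + 1) * per
    mask = (ids >= lo) & (ids < hi)
    local_ids = (ids - lo).clamp(0, per - 1)                     # any in-range index for tokens this shard skips
    part = torch.nn.functional.embedding(local_ids, w[lo:hi])
    part[~mask] = 0.0                                            # zero rows this shard does not own
    partials.append(part)
assert torch.allclose(sum(partials), full)                       # the sum stands in for an all-reduce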
|
||||
|
||||
@@ -375,11 +767,9 @@ def check_loss():
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = A_rank = global_context.get_local_rank(input_parallel_mode)
|
||||
i = B_rank = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = C_rank = global_context.get_local_rank(output_parallel_mode)
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
|
||||
criterion = CrossEntropyLoss3D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
@@ -397,24 +787,79 @@ def check_loss():
|
||||
fwd_start = time.time()
|
||||
loss = criterion(out, target_master)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
|
||||
logger)
|
||||
logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape),
|
||||
fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
logger.info('Rank {} CrossEntropyLoss forward: {}'.format(rank, check_equal(loss, loss_master)))
|
||||
logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
|
||||
|
||||
bwd_start = time.time()
|
||||
loss.backward()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
loss_master.backward()
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} CrossEntropyLoss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
|
||||
logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_vocab_parallel_loss():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
criterion = VocabParallelCrossEntropyLoss3D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
out = torch.chunk(out, DEPTH, dim=-1)[k]
|
||||
out = torch.chunk(out, DEPTH, dim=0)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
loss = criterion(out, target_master)
|
||||
fwd_end = time.time()
|
||||
logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
|
||||
|
||||
bwd_start = time.time()
|
||||
loss.backward()
|
||||
bwd_end = time.time()
|
||||
logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
loss_master.backward()
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
@@ -10,6 +10,7 @@ HIDDEN_SIZE = 8
|
||||
NUM_CLASSES = 8
|
||||
NUM_BLOCKS = 2
|
||||
IMG_SIZE = 16
|
||||
VOCAB_SIZE = 16
|
||||
|
||||
def check_equal(A, B):
|
||||
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
|
||||
|
@@ -7,9 +7,14 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.utils import free_port
|
||||
|
||||
from checks_3d.check_layer_3d import *
|
||||
from checks_3d.check_layer_3d import (check_classifier_given_embed_weight, check_classifier_no_given_weight,
|
||||
check_embed, check_layernorm, check_linear, check_loss, check_patch_embed,
|
||||
check_vocab_parallel_classifier_given_embed_weight,
|
||||
check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed,
|
||||
check_vocab_parallel_loss)
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
@@ -23,13 +28,23 @@ CONFIG = dict(
|
||||
def check_layer():
|
||||
check_linear()
|
||||
check_layernorm()
|
||||
check_classifier()
|
||||
# check_embed()
|
||||
# check_loss()
|
||||
check_classifier_no_given_weight()
|
||||
check_vocab_parallel_classifier_no_given_weight()
|
||||
check_classifier_given_embed_weight()
|
||||
check_vocab_parallel_classifier_given_embed_weight()
|
||||
check_embed()
|
||||
check_patch_embed()
|
||||
check_vocab_parallel_embed()
|
||||
check_loss()
|
||||
check_vocab_parallel_loss()
|
||||
|
||||
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
check_layer()
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()