Mirror of https://github.com/hpcaitech/ColossalAI.git
[legacy] move communication and nn to legacy and refactor logger (#4671)
* [legacy] move communication to legacy (#4640)
* [legacy] refactor logger and clean up legacy code (#4654)
* [legacy] make logger independent of gpc
* [legacy] make optim independent of the registry
* [legacy] move test engine to legacy
* [legacy] move nn to legacy (#4656)
* [checkpointio] fix saving of the hf config
* [test] remove useless rpc pp test
* [legacy] fix nn init
* [example] skip tutorial hybrid parallel example
* [devops] test doc check
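One of the items above, making the logger independent of gpc, changes how callers obtain a logger. Below is a minimal sketch of the decoupled usage; disable_existing_loggers is imported by the test file later in this diff, while get_dist_logger and the ranks filter keyword are assumptions based on the existing public colossalai.logging API rather than part of this commit.

from colossalai.logging import disable_existing_loggers, get_dist_logger

disable_existing_loggers()    # silence loggers created before launch, as the tests below do
logger = get_dist_logger()    # per-process distributed logger; no global_context (gpc) import needed
logger.info("logger no longer depends on the global context", ranks=[0])    # ranks=[0] is an assumed filter kwarg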
@@ -0,0 +1,552 @@
import torch
import torch.distributed as dist
from torch.nn import Parameter

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.nn import (
    Classifier1D,
    Embedding1D,
    Linear1D_Col,
    Linear1D_Row,
    VanillaClassifier,
    VocabParallelClassifier1D,
    VocabParallelCrossEntropyLoss1D,
    VocabParallelEmbedding1D,
)
from colossalai.utils import get_current_device, print_rank_0

from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal


def check_linear_col():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    W_shape = (OUTPUT_SIZE, INPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    W = W.clone()
    W.requires_grad = True

    B_shape = (OUTPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    dist.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = B.clone()
    B.requires_grad = True

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]

    check_equal(out, C)
    print_rank_0('linear_col forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear_col backward: pass')


def check_linear_row():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=-1)[i]
    A = A.clone()
    A.requires_grad = True

    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    W = W.clone()
    W.requires_grad = True

    B_shape = (INPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    dist.broadcast(B_master, src=0)
    B = B_master.clone()
    B.requires_grad = True

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('linear_row forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear_row backward: pass')


def check_embed():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('embed backward: pass')


def check_vocab_parallel_embed():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('vocab parallel embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('vocab parallel embed backward: pass')


def check_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    env.parallel_input_1d = False
    parallel_input_1d = env.parallel_input_1d
    layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    W_master = layer_master.weight.data
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    layer.weight.data.copy_(W)
    B_master = layer_master.bias.data
    dist.broadcast(B_master, src=0)
    B = B_master.clone()
    layer.bias.data.copy_(B)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    if parallel_input_1d:
        A = torch.chunk(A_master, DEPTH, dim=-1)[i]
        A = A.clone()
    else:
        A = A_master.clone()
    A.requires_grad = True

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    if parallel_input_1d:
        A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('classifier (no given weight) backward: pass')


def check_vocab_parallel_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    W_master = layer_master.weight.data
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    layer.weight.data.copy_(W)
    B_master = layer_master.bias.data
    dist.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    layer.bias.data.copy_(B)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = A_master.clone()
    A.requires_grad = True

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]

    check_equal(out, C)
    print_rank_0('vocab parallel classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('vocab parallel classifier (no given weight) backward: pass')


def check_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    env.parallel_input_1d = False
    layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = C_master.clone()
    check_equal(out, C)
    print_rank_0('classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = grad_master.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, embed.weight.grad)

    print_rank_0('classifier (given embed weight) backward: pass')


def check_vocab_parallel_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    env.parallel_input_1d = False
    layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
    layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = torch.chunk(C_master, DEPTH, dim=-1)[i]
    check_equal(out, C)
    print_rank_0('vocab parallel classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    dist.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, embed.weight.grad)

    print_rank_0('vocab parallel classifier (given embed weight) backward: pass')


def check_vocab_parallel_loss():
    device = get_current_device()
    dtype = torch.float32

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    criterion = VocabParallelCrossEntropyLoss1D()
    criterion_master = torch.nn.CrossEntropyLoss()

    out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)
    out_master = torch.randn(out_shape, dtype=dtype, device=device)
    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device)
    torch.distributed.broadcast(out_master, src=0)
    torch.distributed.broadcast(target_master, src=0)
    out = torch.chunk(out_master, DEPTH, dim=-1)[i]
    out = out.clone()
    out.requires_grad = True

    loss = criterion(out, target_master)

    out_master = out_master.clone()
    out_master.requires_grad = True
    loss_master = criterion_master(out_master, target_master)
    check_equal(loss, loss_master)
    print_rank_0('vocab parallel loss forward: pass')

    loss.backward()
    loss_master.backward()

    out_grad = out_master.grad
    out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i]
    check_equal(out_grad, out.grad)
    print_rank_0('vocab parallel loss backward: pass')


@torch.no_grad()
def check_linear_row_stream_inference():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    stream_chunk_num = 4
    assert HIDDEN_SIZE % stream_chunk_num == 0
    layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    dist.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=-1)[i]
    A = A.clone()

    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    dist.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[i]
    W = W.clone()

    B_shape = (INPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    dist.broadcast(B_master, src=0)
    B = B_master.clone()

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    layer.chunk_weight()
    layer.eval()

    out = layer(A)

    A_master = A_master.clone()
    W_master = W_master.clone()
    B_master = B_master.clone()
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = C_master.clone()

    check_equal(out, C)
    print_rank_0('linear_row forward: pass')
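For context on why check_linear_col shards the weight along dim 0 while check_linear_row shards along dim -1: the two layers are designed to be stacked, column-parallel first and row-parallel second, so the intermediate activation can stay sharded between them. The sketch below is illustrative only; the constructor arguments match the checks above, but whether the first layer gathers its output or keeps it sharded depends on the layers' defaults, which this commit does not show.

import torch

from colossalai.legacy.nn import Linear1D_Col, Linear1D_Row


class ParallelMLP(torch.nn.Module):
    """Megatron-style MLP sketch built from the two layers tested above."""

    def __init__(self, hidden_size: int, ratio: int = 4):
        super().__init__()
        # Linear1D_Col shards output features (weight chunked along dim 0);
        # Linear1D_Row shards input features (weight chunked along dim -1).
        # Treat the wiring between them as an assumption about the defaults.
        self.dense_1 = Linear1D_Col(hidden_size, ratio * hidden_size)
        self.dense_2 = Linear1D_Row(ratio * hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dense_2(torch.relu(self.dense_1(x)))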
tests/test_legacy/test_layers/test_1d/checks_1d/common.py (new file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

DEPTH = 4
BATCH_SIZE = 8
SEQ_LENGTH = 8
IMG_SIZE = 16
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True
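Note that DEPTH = 4 matches the 4-way tensor-parallel group spawned by test_1d.py below, and the tolerances in check_equal are loose because the parallel and single-device paths accumulate sums in different orders. A small worked example of what torch.allclose accepts with these settings:

import torch

# torch.allclose(A, B, rtol, atol) passes when |A - B| <= atol + rtol * |B| elementwise,
# so with atol=1e-1 and rtol=1e-3 a reference value of 2.0 may differ by up to ~0.102.
a = torch.tensor([2.0, -1.0])
b = torch.tensor([2.09, -0.95])
assert torch.allclose(a, b, rtol=1e-3, atol=1e-1)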
tests/test_legacy/test_layers/test_1d/test_1d.py (new file, 43 lines)
@@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
from checks_1d.check_layer_1d import *

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)


def check_layer(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    check_linear_col()
    check_linear_row()
    check_embed()
    check_vocab_parallel_embed()
    check_classifier_no_given_weight()
    check_vocab_parallel_classifier_no_given_weight()
    check_classifier_given_embed_weight()
    check_vocab_parallel_classifier_given_embed_weight()
    check_vocab_parallel_loss()

    check_linear_row_stream_inference()

    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_1d():
    spawn(check_layer, 4)


if __name__ == '__main__':
    test_1d()
@@ -0,0 +1,752 @@
import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn import (
    Classifier2D,
    CrossEntropyLoss2D,
    Embedding2D,
    LayerNorm2D,
    Linear2D,
    PatchEmbedding2D,
    VanillaClassifier,
    VanillaPatchEmbedding,
    VocabParallelClassifier2D,
    VocabParallelCrossEntropyLoss2D,
    VocabParallelEmbedding2D,
)
from colossalai.utils import get_current_device, print_rank_0

from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal


def check_linear():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = HIDDEN_SIZE

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = Linear2D(INPUT_SIZE, OUTPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    W = torch.chunk(W, DEPTH, dim=-1)[j]
    W = W.clone()
    W.requires_grad = True

    B_shape = (OUTPUT_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=-1)[j]
    B = torch.chunk(B, DEPTH, dim=-1)[i]
    B = B.clone()
    B.requires_grad = True

    layer.weight.data.copy_(W)
    layer.bias.data.copy_(B)
    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master) + B_master
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('linear forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    # if i == 0:
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear backward: pass')


def check_layernorm():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    EPS = 1e-12

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layernorm = LayerNorm2D(INPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    out = layernorm(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    E_master = torch.sum(A_master, dim=-1, keepdim=True)
    E_master /= INPUT_SIZE
    V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
    V_master /= INPUT_SIZE
    V_master = V_master - E_master * E_master
    V_master = 1.0 / torch.sqrt(V_master + EPS)
    C_master = (A_master - E_master) * V_master
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('layer norm forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)
    print_rank_0('layer norm backward: pass')


def check_embed():
    device = get_current_device()
    dtype = torch.float32
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    check_equal(out, C)
    print_rank_0('embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('embed backward: pass')


def check_patch_embed():
    device = get_current_device()
    dtype = torch.float32
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
    torch.nn.init.ones_(layer.cls_token)
    torch.nn.init.ones_(layer.pos_embed)
    layer = layer.to(device)

    layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
    torch.nn.init.ones_(layer_master.cls_token)
    torch.nn.init.ones_(layer_master.pos_embed)
    layer_master = layer_master.to(device)

    proj_weight_master = layer_master.weight.data
    torch.distributed.broadcast(proj_weight_master, src=0)
    proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j]
    proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i]
    layer.weight.data.copy_(proj_weight)
    proj_bias_master = layer_master.bias.data
    torch.distributed.broadcast(proj_bias_master, src=0)
    proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j]
    proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i]
    layer.bias.data.copy_(proj_bias)

    A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(A)

    A_master = A_master.clone()
    C_master = layer_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    check_equal(out, C)
    print_rank_0('patch embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    cls_grad_master = layer_master.cls_token.grad
    cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j]
    cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i]
    check_equal(cls_grad, layer.cls_token.grad)

    pos_grad_master = layer_master.pos_embed.grad
    pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j]
    pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i]
    check_equal(pos_grad, layer.pos_embed.grad)

    B_grad = layer_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, layer.weight.grad)

    bias_grad = layer_master.bias.grad
    bias_grad = torch.chunk(bias_grad, DEPTH)[j]
    bias_grad = torch.chunk(bias_grad, DEPTH)[i]
    check_equal(bias_grad, layer.bias.grad)
    print_rank_0('patch embed backward: pass')


def check_vocab_parallel_embed():
    device = get_current_device()
    dtype = torch.float32
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
    weight = torch.chunk(weight, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = embed(A)

    A_master = A_master.clone()
    C_master = embed_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    check_equal(out, C)
    print_rank_0('vocab parallel embed forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)
    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    B_grad = embed_master.weight.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    check_equal(B_grad, embed.weight.grad)
    print_rank_0('vocab parallel embed backward: pass')


def check_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = NUM_CLASSES

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    W_shape = (OUTPUT_SIZE, INPUT_SIZE)
    W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=-1)[j]
    W = torch.chunk(W, DEPTH, dim=-1)[i]
    W = W.clone()
    layer.weight.data.copy_(W)
    # W.requires_grad = True

    B_shape = (OUTPUT_SIZE,)
    B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    # B = torch.chunk(B_master, DEPTH, dim=0)[j]
    B = B_master.clone()
    layer.bias.data.copy_(B)

    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    # C = torch.chunk(C, DEPTH, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    # grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    # B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
    # if i == 0:
    check_equal(B_grad, layer.bias.grad)

    print_rank_0('classifier (no given weight) backward: pass')


def check_vocab_parallel_classifier_no_given_weight():
    device = get_current_device()
    dtype = torch.float32

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer = layer.to(dtype).to(device)

    layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
    layer_master = layer_master.to(dtype).to(device)

    weight_master = layer_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
    weight = torch.chunk(weight, DEPTH, dim=-1)[j]
    layer.weight.data.copy_(weight)
    bias_master = layer_master.bias.data
    torch.distributed.broadcast(bias_master, src=0)
    bias = torch.chunk(bias_master, DEPTH)[j]
    bias = torch.chunk(bias, DEPTH)[i]
    layer.bias.data.copy_(bias)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True
    out = layer(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = layer_master(A_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    check_equal(out, C)
    print_rank_0('vocab parallel classifier (no given weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)

    W_grad = layer_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
    check_equal(W_grad, layer.weight.grad)

    B_grad = layer_master.bias.grad
    B_grad = torch.chunk(B_grad, DEPTH)[j]
    B_grad = torch.chunk(B_grad, DEPTH)[i]
    check_equal(B_grad, layer.bias.grad)
    print_rank_0('vocab parallel classifier (no given weight) backward: pass')


def check_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
    embed.weight.data.copy_(weight)

    layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
    layer = layer.to(dtype).to(device)
    layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    check_equal(out, C)
    print_rank_0('classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
    check_equal(W_grad, embed.weight.grad)
    print_rank_0('classifier (given embed weight) backward: pass')


def check_vocab_parallel_classifier_given_embed_weight():
    device = get_current_device()
    dtype = torch.float32

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
    embed = embed.to(dtype).to(device)
    embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
    embed_master = embed_master.to(dtype).to(device)

    weight_master = embed_master.weight.data
    torch.distributed.broadcast(weight_master, src=0)
    weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
    weight = torch.chunk(weight, DEPTH, dim=0)[i]
    embed.weight.data.copy_(weight)

    layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
    layer = layer.to(dtype).to(device)
    layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
    layer_master = layer_master.to(dtype).to(device)

    A_shape = (BATCH_SIZE, SEQ_LENGTH)
    A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = A_master.clone()
    out = layer(embed(A))

    A_master = A_master.clone()
    C_master = layer_master(embed_master(A_master))
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    check_equal(out, C)
    print_rank_0('vocab parallel classifier (given embed weight) forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = grad.clone()
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = embed_master.weight.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    check_equal(W_grad, embed.weight.grad)
    print_rank_0('vocab parallel classifier (given embed weight) backward: pass')


def check_loss():
    device = get_current_device()
    dtype = torch.float32

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    criterion = CrossEntropyLoss2D()
    criterion_master = torch.nn.CrossEntropyLoss()

    out_shape = (BATCH_SIZE, NUM_CLASSES)
    out_master = torch.randn(out_shape, dtype=dtype, device=device)
    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
    torch.distributed.broadcast(out_master, src=0)
    torch.distributed.broadcast(target_master, src=0)
    out = torch.chunk(out_master, DEPTH, dim=0)[i]
    out = out.clone()
    out.requires_grad = True
    loss = criterion(out, target_master)

    out_master = out_master.clone()
    out_master.requires_grad = True
    loss_master = criterion_master(out_master, target_master)
    check_equal(loss, loss_master)
    print_rank_0('cross entropy loss forward: pass')

    loss.backward()
    loss_master.backward()

    out_grad = out_master.grad
    out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
    check_equal(out_grad, out.grad)
    print_rank_0('cross entropy loss backward: pass')


def check_vocab_parallel_loss():
    device = get_current_device()
    dtype = torch.float32

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    criterion = VocabParallelCrossEntropyLoss2D()
    criterion_master = torch.nn.CrossEntropyLoss()

    out_shape = (BATCH_SIZE, NUM_CLASSES)
    out_master = torch.randn(out_shape, dtype=dtype, device=device)
    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
    torch.distributed.broadcast(out_master, src=0)
    torch.distributed.broadcast(target_master, src=0)
    out = torch.chunk(out_master, DEPTH, dim=0)[i]
    out = torch.chunk(out, DEPTH, dim=-1)[j]
    out = out.clone()
    out.requires_grad = True
    loss = criterion(out, target_master)

    out_master = out_master.clone()
    out_master.requires_grad = True
    loss_master = criterion_master(out_master, target_master)
    check_equal(loss, loss_master)
    print_rank_0('vocab parallel cross entropy loss forward: pass')

    loss.backward()
    loss_master.backward()

    out_grad = out_master.grad
    out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
    out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j]
    check_equal(out_grad, out.grad)
    print_rank_0('vocab parallel cross entropy loss backward: pass')


# def check_attention():
#     device = get_current_device()
#     dtype = torch.float32
#     INPUT_SIZE = HIDDEN_SIZE
#     NUM_ATTENTION_HEADS = 2

#     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
#     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

#     layer = TransformerSelfAttention2D(
#         HIDDEN_SIZE,
#         NUM_ATTENTION_HEADS,
#         attention_dropout_prob=0.5,
#         hidden_dropout_prob=0.5,
#     )

#     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
#     A_master = torch.randn(A_shape, dtype=dtype, device=device)
#     torch.distributed.broadcast(A_master, src=0)
#     A = torch.chunk(A_master, DEPTH, dim=0)[i]
#     A = torch.chunk(A, DEPTH, dim=-1)[j]
#     A = A.clone()
#     A.requires_grad = True

#     mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
#     attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

#     out = layer(A, attention_mask)
#     assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
#     print_rank_0('self attention forward: pass')

#     grad_shape = out.shape
#     grad = torch.randn(grad_shape, dtype=dtype, device=device)

#     out.backward(grad)
#     assert A.grad.shape == A.shape
#     print_rank_0('self attention backward: pass')


# def check_mlp():
#     device = get_current_device()
#     dtype = torch.float32
#     INPUT_SIZE = HIDDEN_SIZE

#     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
#     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

#     layer = TransformerMLP2D(
#         HIDDEN_SIZE,
#         dropout_prob=0.5,
#         act_func='gelu',
#     )

#     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
#     A_master = torch.randn(A_shape, dtype=dtype, device=device)
#     torch.distributed.broadcast(A_master, src=0)
#     A = torch.chunk(A_master, DEPTH, dim=0)[i]
#     A = torch.chunk(A, DEPTH, dim=-1)[j]
#     A = A.clone()
#     A.requires_grad = True

#     out = layer(A)
#     assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
#     print_rank_0('mlp forward: pass')

#     grad_shape = out.shape
#     grad = torch.randn(grad_shape, dtype=dtype, device=device)

#     out.backward(grad)
#     assert A.grad.shape == A.shape
#     print_rank_0('mlp backward: pass')


# def check_transformerlayer():
#     device = get_current_device()
#     dtype = torch.float32
#     INPUT_SIZE = HIDDEN_SIZE
#     NUM_ATTENTION_HEADS = 2

#     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
#     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

#     layer = TransformerLayer2D(HIDDEN_SIZE,
#                                NUM_ATTENTION_HEADS,
#                                act_func='gelu',
#                                attention_dropout_prob=0.5,
#                                hidden_dropout_prob=0.5)

#     A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
#     A_master = torch.randn(A_shape, dtype=dtype, device=device)
#     torch.distributed.broadcast(A_master, src=0)
#     A = torch.chunk(A_master, DEPTH, dim=0)[i]
#     A = torch.chunk(A, DEPTH, dim=-1)[j]
#     A = A.clone()
#     A.requires_grad = True

#     mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
#     attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

#     out = layer(A, attention_mask)
#     assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
#     print_rank_0('transformerlayer forward: pass')

#     grad_shape = out.shape
#     grad = torch.randn(grad_shape, dtype=dtype, device=device)

#     out.backward(grad)
#     assert A.grad.shape == A.shape
#     print_rank_0('transformerlayer backward: pass')
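Every 2D check above extracts the (i, j) tile of a replicated tensor with the same two-step chunking. The helper below is only an illustration of that repeated pattern, not part of the test files in this commit:

import torch


def tile_2d(x: torch.Tensor, depth: int, i: int, j: int) -> torch.Tensor:
    """Return the (i, j) tile of ``x`` on a depth x depth process grid.

    Index i (the PARALLEL_2D_COL rank) splits dim 0 and index j (the
    PARALLEL_2D_ROW rank) splits the last dim, mirroring the chained
    torch.chunk calls used throughout the checks above.
    """
    return torch.chunk(torch.chunk(x, depth, dim=0)[i], depth, dim=-1)[j]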
@@ -0,0 +1,213 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
|
||||
|
||||
|
||||
def check_AB():
|
||||
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
|
||||
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
|
||||
ParallelMode.PIPELINE)
|
||||
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
|
||||
ParallelMode.PIPELINE)
|
||||
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
|
||||
dtype = torch.float
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
|
||||
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
B = torch.chunk(B_master, DEPTH, dim=0)[i]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[j]
|
||||
B = B.clone()
|
||||
B.requires_grad = True
|
||||
|
||||
out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH)
|
||||
|
||||
out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
|
||||
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
|
||||
|
||||
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, B_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
# check forward correctness
|
||||
check_equal(out, C)
|
||||
print_rank_0('AB forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
|
||||
# check backward correctness
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
# check backward correctness
|
||||
check_equal(B_grad, B.grad)
|
||||
print_rank_0('AB backward: pass')
|
||||
|
||||
|
||||
def check_ABT():
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    device = get_current_device()

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j,
                              ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
                              pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    A_master = torch.matmul(C_master, B_master.transpose(0, 1))
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    check_equal(out, A)
    print_rank_0('ABT forward: pass')

    grad_shape = A_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]

    # backward
    out.backward(grad)

    A_master.backward(grad_master)
    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
    C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
    check_equal(C_grad, C.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    check_equal(B_grad, B.grad)
    print_rank_0('ABT backward: pass')

def check_ATB():
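    # Checks Matmul_ATB_2D: the mesh-parallel A^T @ C against the full-tensor reference, forward and backward.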
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    device = get_current_device()
    dtype = torch.float

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j,
                              ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
                              pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = torch.matmul(
        A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    check_equal(out, B)
    print_rank_0('ATB forward: pass')

    grad_shape = B_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]

    out.backward(grad)

    B_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)

    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
    C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
    check_equal(C_grad, C.grad)
    print_rank_0('ATB backward: pass')
tests/test_legacy/test_layers/test_2d/checks_2d/common.py (new file, 16 lines)
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
tests/test_legacy/test_layers/test_2d/test_2d.py (new file, 69 lines)
@@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pytest
import torch
from checks_2d.check_layer_2d import (
    check_classifier_given_embed_weight,
    check_classifier_no_given_weight,
    check_embed,
    check_layernorm,
    check_linear,
    check_loss,
    check_patch_embed,
    check_vocab_parallel_classifier_given_embed_weight,
    check_vocab_parallel_classifier_no_given_weight,
    check_vocab_parallel_embed,
    check_vocab_parallel_loss,
)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),)


def check_operations():
    check_AB()
    check_ABT()
    check_ATB()


def check_layer():
    check_linear()
    check_layernorm()
    check_embed()
    check_patch_embed()
    check_vocab_parallel_embed()
    check_classifier_no_given_weight()
    check_vocab_parallel_classifier_no_given_weight()
    check_classifier_given_embed_weight()
    check_vocab_parallel_classifier_given_embed_weight()
    check_loss()
    check_vocab_parallel_loss()


def check_layer_and_operation(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = True
    # check_operations()
    check_layer()
    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_2d():
    spawn(check_layer_and_operation, 4)


if __name__ == '__main__':
    test_2d()
@@ -0,0 +1,765 @@
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.legacy.nn import (
|
||||
Classifier2p5D,
|
||||
CrossEntropyLoss2p5D,
|
||||
Embedding2p5D,
|
||||
LayerNorm2p5D,
|
||||
Linear2p5D,
|
||||
PatchEmbedding2p5D,
|
||||
VanillaClassifier,
|
||||
VanillaPatchEmbedding,
|
||||
VocabParallelClassifier2p5D,
|
||||
VocabParallelCrossEntropyLoss2p5D,
|
||||
VocabParallelEmbedding2p5D,
|
||||
)
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
from .common import *
|
||||
|
||||
|
||||
def check_linear():
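    # Compares Linear2p5D against a plain matmul + bias reference, including input, weight and bias gradients.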
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = 2 * HIDDEN_SIZE
|
||||
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
|
||||
W_master = torch.randn(W_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(W_master, src=0)
|
||||
W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i]
|
||||
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j]
|
||||
W = W.clone()
|
||||
W.requires_grad = True
|
||||
|
||||
B_shape = (OUTPUT_SIZE)
|
||||
B_master = torch.randn(B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
|
||||
B = B.clone()
|
||||
B.requires_grad = True
|
||||
|
||||
layer.weight = Parameter(W)
|
||||
layer.bias = Parameter(B)
|
||||
out = layer(A)
|
||||
bias = layer.bias
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
W_master = W_master.clone()
|
||||
W_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master) + B_master
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('linear forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
|
||||
if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('linear backward: pass')
|
||||
|
||||
|
||||
def check_layernorm():
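    # Compares LayerNorm2p5D with a manually computed layer norm (mean/variance over the hidden dimension).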
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
EPS = 1e-12
|
||||
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
out = layernorm(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
E_master = torch.sum(A_master, dim=-1, keepdim=True)
|
||||
E_master /= INPUT_SIZE
|
||||
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
|
||||
V_master /= INPUT_SIZE
|
||||
V_master = V_master - E_master * E_master
|
||||
V_master = 1.0 / torch.sqrt(V_master + EPS)
|
||||
C_master = (A_master - E_master) * V_master
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('layer norm forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
print_rank_0('layer norm backward: pass')
|
||||
|
||||
|
||||
def check_embed():
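    # Compares Embedding2p5D, initialized from a sliced copy of a torch.nn.Embedding, with that master embedding.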
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = embed(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = embed_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(B_grad, embed.weight.grad)
|
||||
print_rank_0('embed backward: pass')
|
||||
|
||||
|
||||
def check_patch_embed():
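    # Compares PatchEmbedding2p5D with VanillaPatchEmbedding, including cls_token, pos_embed, weight and bias grads.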
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer.cls_token)
|
||||
torch.nn.init.ones_(layer.pos_embed)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
|
||||
torch.nn.init.ones_(layer_master.cls_token)
|
||||
torch.nn.init.ones_(layer_master.pos_embed)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
proj_weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(proj_weight_master, src=0)
|
||||
proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j]
|
||||
proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i]
|
||||
layer.weight.data.copy_(proj_weight)
|
||||
proj_bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(proj_bias_master, src=0)
|
||||
proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j]
|
||||
proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i]
|
||||
layer.bias.data.copy_(proj_bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('patch embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j]
|
||||
cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(cls_grad, layer.cls_token.grad)
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j]
|
||||
pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(pos_grad, layer.pos_embed.grad)
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(B_grad, layer.weight.grad)
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j]
|
||||
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i]
|
||||
check_equal(bias_grad, layer.bias.grad)
|
||||
print_rank_0('patch embed backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_embed():
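    # Compares VocabParallelEmbedding2p5D (weight sharded over the vocab dimension) with a full torch.nn.Embedding.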
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = embed(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = embed_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel embed forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(B_grad, embed.weight.grad)
|
||||
print_rank_0('vocab parallel embed backward: pass')
|
||||
|
||||
|
||||
def check_classifier_no_given_weight():
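    # Compares Classifier2p5D, using its own sharded weight and a replicated bias, against a plain matmul reference.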
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = NUM_CLASSES
|
||||
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
|
||||
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
|
||||
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(W_master, src=0)
|
||||
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
|
||||
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
|
||||
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
|
||||
W = W.clone()
|
||||
layer.weight.data.copy_(W)
|
||||
# W.requires_grad = True
|
||||
|
||||
B_shape = (OUTPUT_SIZE,)
|
||||
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
|
||||
B = B_master.clone()
|
||||
layer.bias.data.copy_(B)
|
||||
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
W_master = W_master.clone()
|
||||
W_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier (no given weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = W_master.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = B_master.grad
|
||||
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
|
||||
# if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
|
||||
print_rank_0('classifier (no given weight) backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_no_given_weight():
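    # Compares VocabParallelClassifier2p5D with VanillaClassifier, sharding both the weight and the bias.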
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, TESSERACT_DIM)[j]
|
||||
layer.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
out = layer(A)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(A_grad, A.grad)
|
||||
|
||||
W_grad = layer_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(W_grad, layer.weight.grad)
|
||||
|
||||
B_grad = layer_master.bias.grad
|
||||
B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j]
|
||||
if i == 0:
|
||||
check_equal(B_grad, layer.bias.grad)
|
||||
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
|
||||
|
||||
|
||||
def check_classifier_given_embed_weight():
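    # Classifier2p5D reusing the embedding weight (weight tying); checks the output and the shared weight gradient.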
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(embed(A))
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(out, C)
|
||||
print_rank_0('classifier (given embed weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = embed_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
|
||||
check_equal(W_grad, embed.weight.grad)
|
||||
print_rank_0('classifier (given embed weight) backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_given_embed_weight():
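    # VocabParallelClassifier2p5D tied to VocabParallelEmbedding2p5D weights; checks output and shared weight grad.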
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
|
||||
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
out = layer(embed(A))
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
|
||||
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out, C)
|
||||
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
|
||||
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
|
||||
grad = grad.clone()
|
||||
out.backward(grad)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
W_grad = embed_master.weight.grad
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(W_grad, embed.weight.grad)
|
||||
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
|
||||
|
||||
|
||||
def check_loss():
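    # CrossEntropyLoss2p5D on batch-sharded logits should match torch.nn.CrossEntropyLoss, including logit gradients.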
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
criterion = CrossEntropyLoss2p5D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
loss = criterion(out, target_master)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
check_equal(loss, loss_master)
|
||||
print_rank_0('cross entropy loss forward: pass')
|
||||
|
||||
loss.backward()
|
||||
loss_master.backward()
|
||||
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
|
||||
check_equal(out_grad, out.grad)
|
||||
print_rank_0('cross entropy loss backward: pass')
|
||||
|
||||
|
||||
def check_vocab_parallel_loss():
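    # VocabParallelCrossEntropyLoss2p5D on logits sharded over batch and class dims, checked against torch.nn.CrossEntropyLoss.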
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
criterion = VocabParallelCrossEntropyLoss2p5D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
|
||||
out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
loss = criterion(out, target_master)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
check_equal(loss, loss_master)
|
||||
print_rank_0('vocab parallel cross entropy loss forward: pass')
|
||||
|
||||
loss.backward()
|
||||
loss_master.backward()
|
||||
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j]
|
||||
check_equal(out_grad, out.grad)
|
||||
print_rank_0('vocab parallel cross entropy loss backward: pass')
|
||||
|
||||
|
||||
# def check_attention():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
# INPUT_SIZE = HIDDEN_SIZE
|
||||
# NUM_ATTENTION_HEADS = 2
|
||||
|
||||
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
# layer = TransformerSelfAttention2p5D(
|
||||
# HIDDEN_SIZE, NUM_ATTENTION_HEADS,
|
||||
# attention_dropout_prob=0.5,
|
||||
# hidden_dropout_prob=0.5,
|
||||
# dtype=dtype,
|
||||
# )
|
||||
|
||||
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
# torch.distributed.broadcast(A_master, src=0)
|
||||
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
# A = A.clone()
|
||||
# A.requires_grad = True
|
||||
|
||||
# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
|
||||
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
|
||||
|
||||
# out = layer(A, attention_mask)
|
||||
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
|
||||
# print_rank_0('self attention forward: pass')
|
||||
|
||||
# grad_shape = out.shape
|
||||
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
|
||||
# out.backward(grad)
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('self attention backward: pass')
|
||||
|
||||
# def check_mlp():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
# INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
# layer = TransformerMLP2p5D(
|
||||
# HIDDEN_SIZE,
|
||||
# mlp_ratio=1,
|
||||
# dropout_prob=0.5,
|
||||
# act_func='gelu',
|
||||
# dtype=dtype,
|
||||
# )
|
||||
|
||||
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
# torch.distributed.broadcast(A_master, src=0)
|
||||
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
# A = A.clone()
|
||||
# A.requires_grad = True
|
||||
|
||||
# out = layer(A)
|
||||
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
|
||||
# print_rank_0('mlp forward: pass')
|
||||
|
||||
# grad_shape = out.shape
|
||||
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
|
||||
# out.backward(grad)
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('mlp backward: pass')
|
||||
|
||||
# def check_transformerlayer():
|
||||
# device = get_current_device()
|
||||
# dtype = torch.float32
|
||||
# INPUT_SIZE = HIDDEN_SIZE
|
||||
# NUM_ATTENTION_HEADS = 2
|
||||
|
||||
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
|
||||
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
|
||||
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
|
||||
|
||||
# layer = TransformerLayer2p5D(
|
||||
# HIDDEN_SIZE,
|
||||
# NUM_ATTENTION_HEADS,
|
||||
# act_func='gelu',
|
||||
# attention_dropout_prob=0.5,
|
||||
# hidden_dropout_prob=0.5,
|
||||
# dtype=dtype,
|
||||
# )
|
||||
|
||||
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
# torch.distributed.broadcast(A_master, src=0)
|
||||
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
|
||||
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
|
||||
# A = A.clone()
|
||||
# A.requires_grad = True
|
||||
|
||||
# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
|
||||
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
|
||||
|
||||
# out = layer(A, attention_mask)
|
||||
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
|
||||
# print_rank_0('transformerlayer forward: pass')
|
||||
|
||||
# grad_shape = out.shape
|
||||
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
|
||||
# out.backward(grad)
|
||||
# assert A.grad.shape == A.shape
|
||||
# print_rank_0('transformerlayer backward: pass')
|
@@ -0,0 +1,215 @@
import torch

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
from colossalai.utils import get_current_device, print_rank_0

from .common import *

def check_AB():
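    # Checks Matmul_AB_2p5D: the tesseract-parallel A @ B should match a single-device matmul, forward and backward.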
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
    B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
    out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW,
                               ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank,
                               pipeline_parallel_size, tensor_parallel_size)

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, B_master)
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
    # check forward correctness
    check_equal(out, C)
    print_rank_0('AB forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]

    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
    # check backward correctness
    check_equal(A_grad, A.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
    # check backward correctness
    check_equal(B_grad, B.grad)
    print_rank_0('AB backward: pass')

def check_ABT():
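    # Checks Matmul_ABT_2p5D: C @ B^T on the tesseract mesh against the full-tensor reference, forward and backward.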
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    device = get_current_device()

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
    B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM,
                                (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k,
                                ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank,
                                pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    A_master = torch.matmul(C_master, B_master.transpose(0, 1))
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    check_equal(out, A)
    print_rank_0('ABT forward: pass')

    grad_shape = A_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]

    # backward
    out.backward(grad)

    A_master.backward(grad_master)
    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(C_grad, C.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(B_grad, B.grad)
    print_rank_0('ABT backward: pass')

def check_ATB():
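    # Checks Matmul_ATB_2p5D: A^T @ C on the tesseract mesh against the full-tensor reference, forward and backward.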
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    device = get_current_device()
    dtype = torch.float

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
                                i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
                                data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
                                tensor_parallel_size)

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = torch.matmul(
        A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
    B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
    check_equal(out, B)
    print_rank_0('ATB forward: pass')

    grad_shape = B_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]

    out.backward(grad)

    B_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(A_grad, A.grad)

    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(C_grad, C.grad)
    print_rank_0('ATB backward: pass')
@@ -0,0 +1,14 @@
import torch

TESSERACT_DIM = 2
TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
tests/test_legacy/test_layers/test_2p5d/test_2p5d.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import pytest
import torch
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn

CONFIG = dict(parallel=dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2.5d', depth=1),
),)


def check_operations():
    check_AB()
    check_ABT()
    check_ATB()


def check_layer():
    check_linear()
    check_layernorm()
    check_embed()
    check_patch_embed()
    check_vocab_parallel_embed()
    check_classifier_no_given_weight()
    check_vocab_parallel_classifier_no_given_weight()
    check_classifier_given_embed_weight()
    check_vocab_parallel_classifier_given_embed_weight()
    check_loss()
    check_vocab_parallel_loss()


def check_layer_and_operation(rank, world_size, port):
    disable_existing_loggers()
    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.deterministic = True
    check_operations()
    check_layer()
    gpc.destroy()
    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_2p5d():
    spawn(check_layer_and_operation, 4)


if __name__ == '__main__':
    test_2p5d()
@@ -0,0 +1,875 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.core import global_context
|
||||
from colossalai.legacy.nn import (
|
||||
Classifier3D,
|
||||
CrossEntropyLoss3D,
|
||||
Embedding3D,
|
||||
LayerNorm3D,
|
||||
Linear3D,
|
||||
PatchEmbedding3D,
|
||||
VanillaClassifier,
|
||||
VanillaPatchEmbedding,
|
||||
VocabParallelClassifier3D,
|
||||
VocabParallelCrossEntropyLoss3D,
|
||||
VocabParallelEmbedding3D,
|
||||
)
|
||||
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
|
||||
|
||||
|
||||
def check_linear():
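    # Compares Linear3D with torch.nn.Linear across the input/weight/output parallel groups and times forward/backward.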
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = 2 * HIDDEN_SIZE
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
|
||||
layer = layer.to(device)
|
||||
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[j]
|
||||
layer.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad.transpose(0, 1)
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_layernorm():
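    # Compares LayerNorm3D with torch.nn.LayerNorm, checking the output and the input, weight and bias gradients.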
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
|
||||
norm = norm.to(device)
|
||||
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
|
||||
norm_master = norm_master.to(device)
|
||||
|
||||
weight_master = norm_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH)[k]
|
||||
norm.weight.data.copy_(weight)
|
||||
bias_master = norm_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[k]
|
||||
norm.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = norm(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
|
||||
fwd_end - fwd_start), logger)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = norm_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
bias_grad = norm_master.weight.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
|
||||
logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad)))
|
||||
|
||||
bias_grad = norm_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
|
||||
logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_classifier_no_given_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
layer.bias.data.copy_(bias_master)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
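# Only the ranks on the j == k diagonal receive a gradient for the classifier
# weight; on every other rank backward never touches the parameter, so its
# .grad is expected to stay None (checked in the else branch below).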
if j == k:
|
||||
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, layer.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
|
||||
rank, layer.weight.grad is None))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_no_given_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
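# The vocab-parallel classifier weight is partitioned twice along the
# vocabulary dimension (input mode j, then weight mode i) and once along the
# hidden dimension (output mode k), matching the DEPTH x DEPTH x DEPTH cube.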
bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[j]
|
||||
layer.bias.data.copy_(bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_classifier_given_embed_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(dtype).to(device)
|
||||
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(dtype).to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(dtype).to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(dtype).to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(embed(A))
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
if j == k:
|
||||
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, embed.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
|
||||
rank, embed.weight.grad is None))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_vocab_parallel_classifier_given_embed_weight():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed = embed.to(device)
|
||||
|
||||
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
embed_master = embed_master.to(device)
|
||||
|
||||
weight_master = embed_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
|
||||
embed.weight.data.copy_(weight)
|
||||
|
||||
layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(embed(A))
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(embed_master(A_master))
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = embed_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
|
||||
check_equal(B_grad,
|
||||
embed.weight.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_patch_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_dist_logger()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
|
||||
torch.nn.init.ones_(layer.cls_token)
|
||||
torch.nn.init.ones_(layer.pos_embed)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
|
||||
torch.nn.init.ones_(layer_master.cls_token)
|
||||
torch.nn.init.ones_(layer_master.pos_embed)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
proj_weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(proj_weight_master, src=0)
|
||||
proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[k]
|
||||
layer.weight.data.copy_(proj_weight)
|
||||
proj_bias_master = layer_master.bias.data
|
||||
torch.distributed.broadcast(proj_bias_master, src=0)
|
||||
proj_bias = torch.chunk(proj_bias_master, DEPTH)[k]
|
||||
layer.bias.data.copy_(proj_bias)
|
||||
|
||||
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
|
||||
A_master = torch.randn(A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
|
||||
fwd_end - fwd_start), logger)
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank,
|
||||
check_equal(pos_grad, layer.pos_embed.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank,
|
||||
check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
|
||||
logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank,
|
||||
check_equal(bias_grad, layer.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_dist_logger()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer = layer.to(device)
|
||||
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
|
||||
fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_vocab_parallel_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_dist_logger()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer = layer.to(device)
|
||||
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
weight_master = layer_master.weight.data
|
||||
torch.distributed.broadcast(weight_master, src=0)
|
||||
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
|
||||
weight = torch.chunk(weight, DEPTH, dim=0)[i]
|
||||
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
|
||||
layer.weight.data.copy_(weight)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH)
|
||||
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
torch.cuda.synchronize()
|
||||
fwd_end = time.time()
|
||||
logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
A_master = A_master.clone()
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
grad = grad.clone()
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
torch.cuda.synchronize()
|
||||
bwd_end = time.time()
|
||||
logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
grad_master = grad_master.clone()
|
||||
C_master.backward(grad_master)
|
||||
|
||||
B_grad = layer_master.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
|
||||
check_equal(B_grad,
|
||||
layer.weight.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_loss():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
|
||||
criterion = CrossEntropyLoss3D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
out = torch.chunk(out, DEPTH, dim=0)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
loss = criterion(out, target_master)
|
||||
fwd_end = time.time()
|
||||
logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape),
|
||||
fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
|
||||
|
||||
bwd_start = time.time()
|
||||
loss.backward()
|
||||
bwd_end = time.time()
|
||||
logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
loss_master.backward()
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_vocab_parallel_loss():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
|
||||
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
|
||||
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
|
||||
|
||||
j = global_context.get_local_rank(input_parallel_mode)
|
||||
i = global_context.get_local_rank(weight_parallel_mode)
|
||||
k = global_context.get_local_rank(output_parallel_mode)
|
||||
|
||||
criterion = VocabParallelCrossEntropyLoss3D()
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
out = torch.chunk(out, DEPTH, dim=-1)[k]
|
||||
out = torch.chunk(out, DEPTH, dim=0)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
loss = criterion(out, target_master)
|
||||
fwd_end = time.time()
|
||||
logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
|
||||
ranks=[0])
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
|
||||
|
||||
bwd_start = time.time()
|
||||
loss.backward()
|
||||
bwd_end = time.time()
|
||||
logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
|
||||
|
||||
loss_master.backward()
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
tests/test_legacy/test_layers/test_3d/checks_3d/common.py (new file)
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
DEPTH = 2
|
||||
BATCH_SIZE = 8
|
||||
SEQ_LENGTH = 8
|
||||
HIDDEN_SIZE = 8
|
||||
NUM_CLASSES = 8
|
||||
NUM_BLOCKS = 2
|
||||
IMG_SIZE = 16
|
||||
VOCAB_SIZE = 16
|
||||
|
||||
|
||||
def check_equal(A, B):
|
||||
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
|
||||
assert eq, f"\nA = {A}\nB = {B}"
|
||||
return eq
|
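# check_equal both asserts and returns the comparison result, so the callers in
# check_layer_3d.py can log a per-rank pass/fail message while still failing fast.
# The loose tolerances (rtol=1e-3, atol=1e-2) leave room for fp32 reduction-order
# differences between the parallel and single-device computations.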
tests/test_legacy/test_layers/test_3d/test_3d.py (new file)
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
import pytest
|
||||
import torch
|
||||
from checks_3d.check_layer_3d import (
|
||||
check_classifier_no_given_weight,
|
||||
check_embed,
|
||||
check_layernorm,
|
||||
check_linear,
|
||||
check_loss,
|
||||
check_patch_embed,
|
||||
check_vocab_parallel_classifier_given_embed_weight,
|
||||
check_vocab_parallel_classifier_no_given_weight,
|
||||
check_vocab_parallel_embed,
|
||||
check_vocab_parallel_loss,
|
||||
)
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
|
||||
|
||||
CONFIG = dict(
|
||||
parallel=dict(
|
||||
pipeline=1,
|
||||
tensor=dict(mode='3d', size=8),
|
||||
),
|
||||
seed=42,
|
||||
)
|
||||
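# With tensor=dict(mode='3d', size=8) the 8 ranks are arranged as a 2 x 2 x 2
# cube, which is why checks_3d/common.py fixes DEPTH = 2 and every partitioned
# dimension is chunked into two pieces.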
|
||||
|
||||
def check_layer():
|
||||
check_linear()
|
||||
check_layernorm()
|
||||
check_classifier_no_given_weight()
|
||||
check_vocab_parallel_classifier_no_given_weight()
|
||||
check_vocab_parallel_classifier_given_embed_weight()
|
||||
check_embed()
|
||||
check_patch_embed()
|
||||
check_vocab_parallel_embed()
|
||||
check_loss()
|
||||
check_vocab_parallel_loss()
|
||||
|
||||
|
||||
def check_layer_and_operation(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
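# TF32 is switched off below so that fp32 matmuls on the GPU stay close enough
# to the reference computation for check_equal's tolerances.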
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
check_layer()
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@skip_if_not_enough_gpus(min_gpus=8)
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_3d():
|
||||
spawn(check_layer_and_operation, 8)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_3d()
|
tests/test_legacy/test_layers/test_cache_embedding.py (new file)
@@ -0,0 +1,377 @@
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.legacy.nn.parallel.layers import (
|
||||
CachedEmbeddingBag,
|
||||
CachedParamMgr,
|
||||
EvictionStrategy,
|
||||
ParallelCachedEmbeddingBag,
|
||||
ParallelCachedEmbeddingBagTablewise,
|
||||
TablewiseEmbeddingBagConfig,
|
||||
)
|
||||
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
NUM_EMBED, EMBED_DIM = 10, 8
|
||||
BATCH_SIZE = 8
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""
|
||||
To achieve reproducible results, it's necessary to fix random seeds
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def synthesize_1d_sparse_feature(
|
||||
batch_size,
|
||||
num_embed,
|
||||
device,
|
||||
):
|
||||
indices_in_batch = batch_size * 2
|
||||
indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
|
||||
offsets = torch.from_numpy(
|
||||
np.array([
|
||||
0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch
|
||||
])).to(device).long()
|
||||
return indices, offsets
|
||||
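# synthesize_1d_sparse_feature builds EmbeddingBag-style inputs: a flat tensor of
# 2 * batch_size random ids plus 2 * batch_size + 1 non-decreasing offsets, i.e.
# 2 * batch_size bags in the include_last_offset=True convention used below.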
|
||||
|
||||
@pytest.mark.skip
|
||||
@clear_cache_before_run()
|
||||
def test_cachemgr():
|
||||
model = torch.nn.EmbeddingBag(10000, 128)
|
||||
# 10 chunks, 5 in cuda
|
||||
mgr = CachedParamMgr(model.weight.detach(), 5)
|
||||
assert mgr.cuda_row_num == 5
|
||||
|
||||
mgr._admit(1)
|
||||
assert not mgr._chunk_in_cuda(2)
|
||||
assert mgr._chunk_in_cuda(1)
|
||||
|
||||
# print(mgr.cached_chunk_table)
|
||||
mgr._admit(8)
|
||||
|
||||
# now 3 chunks are available
|
||||
assert mgr.cuda_available_chunk_num == 3
|
||||
|
||||
mgr._evict()
|
||||
assert mgr.cuda_available_chunk_num == 4
|
||||
|
||||
mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0))
|
||||
mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0))
|
||||
# print(mgr.cached_chunk_table)
|
||||
# mgr.print_comm_stats()
|
||||
|
||||
mgr.flush()
|
||||
assert mgr.cuda_available_chunk_num == 5
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def test_reorder_with_freq():
|
||||
num_embed = 100
|
||||
chunk_size = 1
|
||||
num_chunk = 5
|
||||
|
||||
idx_map = torch.randint(10000, size=(num_embed,))
|
||||
sorted_idx = torch.argsort(idx_map, descending=True).tolist()
|
||||
chunkid, offset_in_chunk = [], []
|
||||
for i in range(num_embed):
|
||||
idx = sorted_idx.index(i)
|
||||
chunkid.append(idx // chunk_size)
|
||||
offset_in_chunk.append(idx % chunk_size)
|
||||
|
||||
dev = torch.device('cuda')
|
||||
chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev)
|
||||
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)
|
||||
|
||||
weight = torch.rand(num_embed, 2)
|
||||
mgr = CachedParamMgr(weight, num_chunk)
|
||||
|
||||
mgr.reorder(idx_map)
|
||||
|
||||
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
|
||||
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
|
||||
mgr_offsets = torch.remainder(indices, chunk_size)
|
||||
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
|
||||
assert torch.allclose(offset_in_chunk, mgr_offsets), \
|
||||
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('use_LFU', [True, False])
|
||||
def test_freq_aware_embed(use_LFU: bool):
|
||||
device = torch.device('cuda', 0)
|
||||
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
|
||||
model = CachedEmbeddingBag(NUM_EMBED,
|
||||
EMBED_DIM,
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
|
||||
ids_freq_mapping=None,
|
||||
evict_strategy=evict_strategy).to(device)
|
||||
|
||||
assert model.weight.shape[0] == NUM_EMBED
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
|
||||
mode='mean',
|
||||
include_last_offset=True,
|
||||
freeze=False)
|
||||
|
||||
assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
|
||||
|
||||
for i in range(5):
|
||||
indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
|
||||
res = model(indices, offsets)
|
||||
ref_res = ref_model(indices, offsets)
|
||||
assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"
|
||||
|
||||
grad = torch.rand_like(res)
|
||||
# comparing gradient here is nontrivial
|
||||
res.backward(grad)
|
||||
ref_res.backward(grad)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
ref_optimizer.step()
|
||||
ref_optimizer.zero_grad()
|
||||
|
||||
model.cache_weight_mgr.flush()
|
||||
model_weight = model.weight.detach().to(device)
|
||||
ref_weight = ref_model.weight.detach()
|
||||
assert torch.allclose(model_weight, ref_weight), \
|
||||
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('init_freq', [True, False])
|
||||
def test_lfu_strategy(init_freq: bool):
|
||||
# minimal test to check behavior
|
||||
Bag = CachedEmbeddingBag(5,
|
||||
5,
|
||||
cache_ratio=3 / 5,
|
||||
buffer_size=0,
|
||||
pin_weight=True,
|
||||
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
|
||||
warmup_ratio=1.0,
|
||||
evict_strategy=EvictionStrategy.LFU)
|
||||
|
||||
# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
|
||||
offsets = torch.tensor([0], device="cuda:0")
|
||||
|
||||
# prepare frequency learning info:
|
||||
Bag.forward(torch.tensor([2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
|
||||
|
||||
# check strategy
|
||||
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
|
||||
Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1
|
||||
Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit
|
||||
Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3
|
||||
Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit
|
||||
Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit
|
||||
|
||||
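# num_hits_history[-6:] covers the six most recent forwards above:
# [0, 1, 2] -> 3 hits, [3] -> miss, [2] -> hit, [4] -> miss, [2] -> hit, [0] -> hit,
# hence the expected [3, 0, 1, 0, 1, 1].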
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
|
||||
"LFU strategy behavior failed"
|
||||
|
||||
|
||||
def gather_tensor(tensor, rank, world_size):
|
||||
gather_list = []
|
||||
if rank == 0:
|
||||
gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
|
||||
torch.distributed.gather(tensor, gather_list, dst=0)
|
||||
return gather_list
|
||||
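# gather_tensor allocates receive buffers only on rank 0 because
# torch.distributed.gather uses gather_list solely on the destination rank;
# every other rank simply gets the empty list back.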
|
||||
|
||||
def run_parallel_freq_aware_embed_tablewise(rank, world_size):
|
||||
if world_size != 2:
|
||||
return
|
||||
device = torch.device('cuda', torch.cuda.current_device())
|
||||
|
||||
# initialize weight
|
||||
# 3 feature tables. idx: 0~5, 6~10, 11~17
|
||||
weight_tables = torch.rand(18, 5)
|
||||
weight_table1 = weight_tables[0:6]
|
||||
weight_table2 = weight_tables[6:11]
|
||||
weight_table3 = weight_tables[11:18]
|
||||
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
|
||||
embedding_bag_config_list.append(
|
||||
TablewiseEmbeddingBagConfig(num_embeddings=6,
|
||||
cuda_row_num=4,
|
||||
assigned_rank=0,
|
||||
initial_weight=weight_table1.clone().detach().cpu()))
|
||||
embedding_bag_config_list.append(
|
||||
TablewiseEmbeddingBagConfig(num_embeddings=5,
|
||||
cuda_row_num=4,
|
||||
assigned_rank=0,
|
||||
initial_weight=weight_table2.clone().detach().cpu()))
|
||||
embedding_bag_config_list.append(
|
||||
TablewiseEmbeddingBagConfig(num_embeddings=7,
|
||||
cuda_row_num=4,
|
||||
assigned_rank=1,
|
||||
initial_weight=weight_table3.clone().detach().cpu()))
|
||||
if rank == 0:
|
||||
_weight = torch.cat([weight_table1, weight_table2], 0)
|
||||
else:
|
||||
_weight = weight_table3
|
||||
model = ParallelCachedEmbeddingBagTablewise(
|
||||
embedding_bag_config_list,
|
||||
embedding_dim=5,
|
||||
_weight=_weight,
|
||||
include_last_offset=True,
|
||||
cache_ratio=0.5,
|
||||
buffer_size=0,
|
||||
evict_strategy=EvictionStrategy.LFU,
|
||||
)
|
||||
# explanation of the synthetic input below (KJT format):
|
||||
'''
|
||||
batch feature 1 feature 2 feature 3
|
||||
input0 [1,2,3] [6,7] []
|
||||
input1 [] [9] [13,15]
|
||||
input2 [1,5] [6,8] [11]
|
||||
↑ ↑ ↑
|
||||
rank 0 rank 0 rank 1
|
||||
in KJT format
|
||||
'''
|
||||
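# The flat tensors below encode that table feature-by-feature, one bag per
# (sample, feature) pair:
#   indices = [1, 2, 3, 1, 5] + [6, 7, 9, 6, 8] + [13, 15, 11]
#   offsets = [0, 3, 3, 5, 7, 8, 10, 10, 12, 13]
# e.g. offsets[1:3] == (3, 3) marks feature 1 of input1 as an empty bag, and the
# trailing 13 is the extra boundary required by include_last_offset=True.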
res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
|
||||
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
|
||||
already_split_along_rank=False)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
|
||||
rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
|
||||
if rank == 0:
|
||||
fake_grad = rand_grad[0:2]
|
||||
else:
|
||||
fake_grad = rand_grad[2:]
|
||||
res.backward(fake_grad)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# check correctness
|
||||
if rank == 0:
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
|
||||
include_last_offset=True,
|
||||
freeze=False).to(device)
|
||||
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
|
||||
ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
|
||||
ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
|
||||
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
|
||||
ref_res.backward(ref_fake_grad)
|
||||
ref_optimizer.step()
|
||||
ref_optimizer.zero_grad()
|
||||
|
||||
model.cache_weight_mgr.flush()
|
||||
recover_weight = model.cache_weight_mgr.weight.to(device)
|
||||
ref_weight = ref_model.weight.detach()[:11]
|
||||
assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
|
||||
|
||||
|
||||
def run_parallel_freq_aware_embed_columnwise(rank, world_size):
|
||||
device = torch.device('cuda', torch.cuda.current_device())
|
||||
|
||||
num_embed = 100
|
||||
embed_dim = 16
|
||||
batch_size = 4
|
||||
|
||||
set_seed(4321)
|
||||
weight = torch.rand(num_embed, embed_dim)
|
||||
coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)
|
||||
|
||||
# initialize the tensor spec for the embedding weight parameter,
|
||||
# which is a ColoParameter.
|
||||
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
|
||||
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
model = ParallelCachedEmbeddingBag.from_pretrained(
|
||||
coloweight,
|
||||
include_last_offset=True,
|
||||
freeze=False,
|
||||
cache_ratio=batch_size * 2 / num_embed,
|
||||
)
|
||||
|
||||
assert model.cache_weight_mgr.weight.device.type == 'cpu'
|
||||
assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
|
||||
weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
|
||||
print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
|
||||
assert torch.allclose(weight_in_rank,
|
||||
model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"
|
||||
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
|
||||
|
||||
if rank == 0:
|
||||
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
|
||||
include_last_offset=True,
|
||||
freeze=False).to(device)
|
||||
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
|
||||
|
||||
set_seed(4321)
|
||||
for i in range(5):
|
||||
indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
|
||||
res = model(indices, offsets)
|
||||
|
||||
grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
|
||||
grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
|
||||
res.backward(grad_in_rank)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
res_list = gather_tensor(res.detach(), rank, world_size)
|
||||
|
||||
if rank == 0:
|
||||
ref_res = ref_model(indices, offsets)
|
||||
recover_res = torch.cat(res_list, dim=0)
|
||||
|
||||
assert torch.allclose(ref_res, recover_res)
|
||||
|
||||
ref_res.backward(grad)
|
||||
ref_optimizer.step()
|
||||
ref_optimizer.zero_grad()
|
||||
|
||||
model.cache_weight_mgr.flush()
|
||||
weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
|
||||
if rank == 0:
|
||||
recover_weight = torch.cat(weight_list, dim=1)
|
||||
assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
# run_parallel_freq_aware_embed_columnwise(rank, world_size)
|
||||
run_parallel_freq_aware_embed_tablewise(rank, world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_parallel_freq_aware_embed(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_freq_aware_embed(True)
|
||||
test_parallel_freq_aware_embed(2)
|
||||
# test_lfu_strategy(False)
|
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.legacy.nn import TransformerSelfAttentionRing
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def check_selfattention():
|
||||
WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
SUB_SEQ_LENGTH = 8
|
||||
BATCH = 4
|
||||
HIDDEN_SIZE = 16
|
||||
|
||||
layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
|
||||
layer = layer.to(get_current_device())
|
||||
|
||||
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
|
||||
attention_mask = torch.randint(low=0, high=2,
|
||||
size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
|
||||
out = layer(hidden_states, attention_mask)
|
tests/test_legacy/test_layers/test_sequence/test_sequence.py (new file)
@@ -0,0 +1,139 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
|
||||
|
||||
|
||||
def check_ring_qk(rank, world_size):
|
||||
# params
|
||||
batch_size = 4
|
||||
num_heads = 4
|
||||
seq_length = 32
|
||||
attention_head_size = 32
|
||||
sub_seq_length = seq_length // world_size
|
||||
|
||||
# create master tensors
|
||||
q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
|
||||
k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
|
||||
dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
|
||||
# create distributed tensors
|
||||
sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
|
||||
sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
|
||||
|
||||
# set autograd attributes
|
||||
q.requires_grad = True
|
||||
k.requires_grad = True
|
||||
q.retain_grad()
|
||||
k.retain_grad()
|
||||
sub_q.requires_grad = True
|
||||
sub_k.requires_grad = True
|
||||
sub_q.retain_grad()
|
||||
sub_k.retain_grad()
|
||||
|
||||
# compute master attention scores
|
||||
a = torch.matmul(q, k.transpose(2, 1))
|
||||
|
||||
# compute distributed attention scores
|
||||
ring_qk = RingQK.apply
|
||||
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
|
||||
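# sub_a holds this rank's sub-sequence of query rows scored against the full key
# sequence (shape: batch*heads x sub_seq_length x seq_length); RingQK assembles it
# by exchanging key chunks between the sequence-parallel ranks.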
|
||||
# check master and distributed attention scores
|
||||
sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
|
||||
assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
|
||||
|
||||
# run master backward
|
||||
a.retain_grad()
|
||||
a.mean().backward()
|
||||
|
||||
# run distributed backward
|
||||
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
|
||||
torch.autograd.backward(sub_a, partial_master_a_grad)
|
||||
|
||||
# check master and distributed grads
|
||||
partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
|
||||
assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
|
||||
'attention score cannot match'
|
||||
|
||||
|
||||
def check_ring_av(rank, world_size):
|
||||
# params
|
||||
batch_size = 4
|
||||
num_heads = 4
|
||||
seq_length = 16
|
||||
attention_head_size = 32
|
||||
sub_seq_length = seq_length // world_size
|
||||
|
||||
# create master tensors
|
||||
a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
|
||||
v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
|
||||
dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
|
||||
# create distributed tensors
|
||||
sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
|
||||
sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
|
||||
|
||||
# set autograd attributes
|
||||
a.requires_grad = True
|
||||
v.requires_grad = True
|
||||
a.retain_grad()
|
||||
v.retain_grad()
|
||||
sub_a.requires_grad = True
|
||||
sub_v.requires_grad = True
|
||||
sub_a.retain_grad()
|
||||
sub_v.retain_grad()
|
||||
|
||||
# compute master attention scores
|
||||
out = torch.matmul(a, v)
|
||||
|
||||
# compute distributed attention scores
|
||||
ring_av = RingAV.apply
|
||||
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
|
||||
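# sub_out is this rank's sub-sequence of output rows: the local block of
# attention scores applied to the full value sequence, which RingAV gathers by
# circulating value chunks around the sequence-parallel ring.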
|
||||
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
|
||||
|
||||
# check master and distributed output
|
||||
sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
|
||||
assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
|
||||
|
||||
# run master backward
|
||||
out.retain_grad()
|
||||
out.mean().backward()
|
||||
|
||||
# run distributed backward
|
||||
partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
|
||||
torch.autograd.backward(sub_out, partial_master_out_grad)
|
||||
|
||||
# check master and distributed grads
|
||||
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
|
||||
assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
|
||||
'attention output cannot match'
|
||||
|
||||
|
||||
def run_test(rank, world_size, port):
|
||||
colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port)
|
||||
|
||||
# check_ring_qk(rank, world_size)
|
||||
check_ring_av(rank, world_size)
|
||||
|
||||
gpc.destroy()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_sequence():
|
||||
spawn(run_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sequence()
|