[legacy] move communication and nn to legacy and refactor logger (#4671)

* [legacy] move communication to legacy (#4640)

* [legacy] refactor logger and clean up legacy codes (#4654)

* [legacy] make logger independent to gpc

* [legacy] make optim independent to registry

* [legacy] move test engine to legacy

* [legacy] move nn to legacy (#4656)

* [legacy] move nn to legacy

* [checkpointio] fix save hf config

* [test] remove useledd rpc pp test

* [legacy] fix nn init

* [example] skip tutorial hybriad parallel example

* [devops] test doc check

* [devops] test doc check
This commit is contained in:
Hongxin Liu
2023-09-11 16:24:28 +08:00
committed by GitHub
parent 536397cc95
commit 554aa9592e
170 changed files with 781 additions and 758 deletions

View File

@@ -0,0 +1,552 @@
import torch
import torch.distributed as dist
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.legacy.nn import (
Classifier1D,
Embedding1D,
Linear1D_Col,
Linear1D_Row,
VanillaClassifier,
VocabParallelClassifier1D,
VocabParallelCrossEntropyLoss1D,
VocabParallelEmbedding1D,
)
from colossalai.utils import get_current_device, print_rank_0
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear_col():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
dist.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('linear_col forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_col backward: pass')
def check_linear_row():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[i]
W = W.clone()
W.requires_grad = True
B_shape = (INPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
dist.broadcast(B_master, src=0)
B = B_master.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = C_master.clone()
check_equal(out, C)
print_rank_0('linear_row forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear_row backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
env.parallel_input_1d = False
parallel_input_1d = env.parallel_input_1d
layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True)
layer_master = layer_master.to(dtype).to(device)
W_master = layer_master.weight.data
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[i]
layer.weight.data.copy_(W)
B_master = layer_master.bias.data
dist.broadcast(B_master, src=0)
B = B_master.clone()
layer.bias.data.copy_(B)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
if parallel_input_1d:
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
A = A.clone()
else:
A = A_master.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = C_master.clone()
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
if parallel_input_1d:
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
W_master = layer_master.weight.data
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
layer.weight.data.copy_(W)
B_master = layer_master.bias.data
dist.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
layer.bias.data.copy_(B)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = A_master.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
env.parallel_input_1d = False
layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = C_master.clone()
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
env.parallel_input_1d = False
layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False)
layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=-1)[i]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
criterion = VocabParallelCrossEntropyLoss1D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=-1)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel loss backward: pass')
@torch.no_grad()
def check_linear_row_stream_inference():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
stream_chunk_num = 4
assert HIDDEN_SIZE % stream_chunk_num == 0
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num)
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
dist.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=-1)[i]
A = A.clone()
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
dist.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[i]
W = W.clone()
B_shape = (INPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
dist.broadcast(B_master, src=0)
B = B_master.clone()
layer.weight = Parameter(W)
layer.bias = Parameter(B)
layer.chunk_weight()
layer.eval()
out = layer(A)
A_master = A_master.clone()
W_master = W_master.clone()
B_master = B_master.clone()
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = C_master.clone()
check_equal(out, C)
print_rank_0('linear_row forward: pass')

View File

@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
DEPTH = 4
BATCH_SIZE = 8
SEQ_LENGTH = 8
IMG_SIZE = 16
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True

View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
import torch
from checks_1d.check_layer_1d import *
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),)
def check_layer(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_col()
check_linear_row()
check_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_vocab_parallel_loss()
check_linear_row_stream_inference()
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_1d():
spawn(check_layer, 4)
if __name__ == '__main__':
test_1d()

View File

@@ -0,0 +1,752 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn import (
Classifier2D,
CrossEntropyLoss2D,
Embedding2D,
LayerNorm2D,
Linear2D,
PatchEmbedding2D,
VanillaClassifier,
VanillaPatchEmbedding,
VocabParallelClassifier2D,
VocabParallelCrossEntropyLoss2D,
VocabParallelEmbedding2D,
)
from colossalai.utils import get_current_device, print_rank_0
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = HIDDEN_SIZE
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = Linear2D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=0)[i]
W = torch.chunk(W, DEPTH, dim=-1)[j]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=-1)[j]
B = torch.chunk(B, DEPTH, dim=-1)[i]
B = B.clone()
B.requires_grad = True
layer.weight.data.copy_(W)
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master) + B_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('linear forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layernorm = LayerNorm2D(INPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
def check_patch_embed():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = PatchEmbedding2D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
proj_weight_master = layer_master.weight.data
torch.distributed.broadcast(proj_weight_master, src=0)
proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[j]
proj_weight = torch.chunk(proj_weight, DEPTH, dim=0)[i]
layer.weight.data.copy_(proj_weight)
proj_bias_master = layer_master.bias.data
torch.distributed.broadcast(proj_bias_master, src=0)
proj_bias = torch.chunk(proj_bias_master, DEPTH, dim=0)[j]
proj_bias = torch.chunk(proj_bias, DEPTH, dim=0)[i]
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(A)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('patch embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[j]
cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[i]
check_equal(cls_grad, layer.cls_token.grad)
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[j]
pos_grad = torch.chunk(pos_grad, DEPTH, dim=-1)[i]
check_equal(pos_grad, layer.pos_embed.grad)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, layer.weight.grad)
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
check_equal(bias_grad, layer.bias.grad)
print_rank_0('patch embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = Classifier2D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, DEPTH, dim=-1)[j]
W = torch.chunk(W, DEPTH, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, DEPTH, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, DEPTH, dim=0)[i]
# C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
bias = torch.chunk(bias, DEPTH)[i]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, DEPTH)[j]
B_grad = torch.chunk(B_grad, DEPTH)[i]
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = Embedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
embed.weight.data.copy_(weight)
layer = Classifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
embed = VocabParallelEmbedding2D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier2D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_loss():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
criterion = CrossEntropyLoss2D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
check_equal(out_grad, out.grad)
print_rank_0('cross entropy loss backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
criterion = VocabParallelCrossEntropyLoss2D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=-1)[j]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[j]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel cross entropy loss backward: pass')
# def check_attention():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
# layer = TransformerSelfAttention2D(
# HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, DEPTH, dim=0)[i]
# A = torch.chunk(A, DEPTH, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
# print_rank_0('self attention forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass')
# def check_mlp():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
# layer = TransformerMLP2D(
# HIDDEN_SIZE,
# dropout_prob=0.5,
# act_func='gelu',
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, DEPTH, dim=0)[i]
# A = torch.chunk(A, DEPTH, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# out = layer(A)
# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
# print_rank_0('mlp forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass')
# def check_transformerlayer():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
# layer = TransformerLayer2D(HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# act_func='gelu',
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5)
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, DEPTH, dim=0)[i]
# A = torch.chunk(A, DEPTH, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
# print_rank_0('transformerlayer forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('transformerlayer backward: pass')

View File

@@ -0,0 +1,213 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
from colossalai.utils import get_current_device, print_rank_0
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
def check_AB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH)
out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, B_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
# check forward correctness
check_equal(out, C)
print_rank_0('AB forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
# check backward correctness
check_equal(A_grad, A.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
# check backward correctness
check_equal(B_grad, B.grad)
print_rank_0('AB backward: pass')
def check_ABT():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = C.clone()
C.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
C_master = C_master.clone()
C_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
A_master = torch.matmul(C_master, B_master.transpose(0, 1))
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
check_equal(out, A)
print_rank_0('ABT forward: pass')
grad_shape = A_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
# backward
out.backward(grad)
A_master.backward(grad_master)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
check_equal(C_grad, C.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
check_equal(B_grad, B.grad)
print_rank_0('ABT backward: pass')
def check_ATB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
dtype = torch.float
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
A = A.clone()
A.requires_grad = True
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = C.clone()
C.requires_grad = True
out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = C_master.clone()
C_master.requires_grad = True
B_master = torch.matmul(
A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
check_equal(out, B)
print_rank_0('ATB forward: pass')
grad_shape = B_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
out.backward(grad)
B_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
check_equal(A_grad, A.grad)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
check_equal(C_grad, C.grad)
print_rank_0('ATB backward: pass')

View File

@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
import torch
from checks_2d.check_layer_2d import (
check_classifier_given_embed_weight,
check_classifier_no_given_weight,
check_embed,
check_layernorm,
check_linear,
check_loss,
check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight,
check_vocab_parallel_embed,
check_vocab_parallel_loss,
)
from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),)
def check_operations():
check_AB()
check_ABT()
check_ATB()
def check_layer():
check_linear()
check_layernorm()
check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
# check_operations()
check_layer()
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_2d():
spawn(check_layer_and_operation, 4)
if __name__ == '__main__':
test_2d()

View File

@@ -0,0 +1,765 @@
import torch
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn import (
Classifier2p5D,
CrossEntropyLoss2p5D,
Embedding2p5D,
LayerNorm2p5D,
Linear2p5D,
PatchEmbedding2p5D,
VanillaClassifier,
VanillaPatchEmbedding,
VocabParallelClassifier2p5D,
VocabParallelCrossEntropyLoss2p5D,
VocabParallelEmbedding2p5D,
)
from colossalai.utils import get_current_device, print_rank_0
from .common import *
def check_linear():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
bias = layer.bias
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('linear forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
def check_patch_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
proj_weight_master = layer_master.weight.data
torch.distributed.broadcast(proj_weight_master, src=0)
proj_weight = torch.chunk(proj_weight_master, TESSERACT_DIM, dim=0)[j]
proj_weight = torch.chunk(proj_weight, TESSERACT_DIM, dim=0)[i]
layer.weight.data.copy_(proj_weight)
proj_bias_master = layer_master.bias.data
torch.distributed.broadcast(proj_bias_master, src=0)
proj_bias = torch.chunk(proj_bias_master, TESSERACT_DIM, dim=0)[j]
proj_bias = torch.chunk(proj_bias, TESSERACT_DIM, dim=0)[i]
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(A)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('patch embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, TESSERACT_DIM, dim=-1)[j]
cls_grad = torch.chunk(cls_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(cls_grad, layer.cls_token.grad)
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, TESSERACT_DIM, dim=-1)[j]
pos_grad = torch.chunk(pos_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(pos_grad, layer.pos_embed.grad)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
check_equal(B_grad, layer.weight.grad)
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j]
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i]
check_equal(bias_grad, layer.bias.grad)
print_rank_0('patch embed backward: pass')
def check_vocab_parallel_embed():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
embed.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = embed(A)
A_master = A_master.clone()
C_master = embed_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
def check_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
layer = Classifier2p5D(INPUT_SIZE, OUTPUT_SIZE)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randint(5, A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (OUTPUT_SIZE, INPUT_SIZE)
W_master = torch.randint(5, W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
# W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W_master, TESSERACT_DIM, dim=-1)[j]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[i]
W = W.clone()
layer.weight.data.copy_(W)
# W.requires_grad = True
B_shape = (OUTPUT_SIZE,)
B_master = torch.randint(5, B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
# B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B_master.clone()
layer.bias.data.copy_(B)
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
# B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(dtype).to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=0)[i]
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[j]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, TESSERACT_DIM)[j]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = layer_master.weight.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = layer_master.bias.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
def check_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=-1)[i]
embed.weight.data.copy_(weight)
layer = Classifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, TESSERACT_DIM, dim=-1)[j]
weight = torch.chunk(weight, TESSERACT_DIM, dim=0)[i]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
out = layer(embed(A))
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
grad = grad.clone()
out.backward(grad)
grad_master = grad_master.clone()
C_master.backward(grad_master)
W_grad = embed_master.weight.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
def check_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
criterion = CrossEntropyLoss2p5D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
check_equal(out_grad, out.grad)
print_rank_0('cross entropy loss backward: pass')
def check_vocab_parallel_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
criterion = VocabParallelCrossEntropyLoss2p5D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i]
out = torch.chunk(out, TESSERACT_DIM, dim=-1)[j]
out = out.clone()
out.requires_grad = True
loss = criterion(out, target_master)
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel cross entropy loss forward: pass')
loss.backward()
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel cross entropy loss backward: pass')
# def check_attention():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
# layer = TransformerSelfAttention2p5D(
# HIDDEN_SIZE, NUM_ATTENTION_HEADS,
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5,
# dtype=dtype,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
# print_rank_0('self attention forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('self attention backward: pass')
# def check_mlp():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
# layer = TransformerMLP2p5D(
# HIDDEN_SIZE,
# mlp_ratio=1,
# dropout_prob=0.5,
# act_func='gelu',
# dtype=dtype,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# out = layer(A)
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
# print_rank_0('mlp forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('mlp backward: pass')
# def check_transformerlayer():
# device = get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
# i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
# j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
# k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
# layer = TransformerLayer2p5D(
# HIDDEN_SIZE,
# NUM_ATTENTION_HEADS,
# act_func='gelu',
# attention_dropout_prob=0.5,
# hidden_dropout_prob=0.5,
# dtype=dtype,
# )
# A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
# A_master = torch.randn(A_shape, dtype=dtype, device=device)
# torch.distributed.broadcast(A_master, src=0)
# A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
# A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
# A = A.clone()
# A.requires_grad = True
# mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
# attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
# out = layer(A, attention_mask)
# assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
# print_rank_0('transformerlayer forward: pass')
# grad_shape = out.shape
# grad = torch.randn(grad_shape, dtype=dtype, device=device)
# out.backward(grad)
# assert A.grad.shape == A.shape
# print_rank_0('transformerlayer backward: pass')

View File

@@ -0,0 +1,215 @@
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
from colossalai.utils import get_current_device, print_rank_0
from .common import *
def check_AB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank,
pipeline_parallel_size, tensor_parallel_size)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, B_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
# check forward correctness
check_equal(out, C)
print_rank_0('AB forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(A_grad, A.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(B_grad, B.grad)
print_rank_0('AB backward: pass')
def check_ABT():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
C = C.clone()
C.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM,
(BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k,
ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank,
pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
C_master = C_master.clone()
C_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
A_master = torch.matmul(C_master, B_master.transpose(0, 1))
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
check_equal(out, A)
print_rank_0('ABT forward: pass')
grad_shape = A_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
# backward
out.backward(grad)
A_master.backward(grad_master)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(B_grad, B.grad)
print_rank_0('ABT backward: pass')
def check_ATB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
C = C.clone()
C.requires_grad = True
out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = C_master.clone()
C_master.requires_grad = True
B_master = torch.matmul(
A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
check_equal(out, B)
print_rank_0('ATB forward: pass')
grad_shape = B_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
B_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
print_rank_0('ATB backward: pass')

View File

@@ -0,0 +1,14 @@
import torch
TESSERACT_DIM = 2
TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
VOCAB_SIZE = 16
IMG_SIZE = 16
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)

View File

@@ -0,0 +1,57 @@
import pytest
import torch
from checks_2p5d.check_layer_2p5d import *
from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2.5d', depth=1),
),)
def check_operations():
check_AB()
check_ABT()
check_ATB()
def check_layer():
check_linear()
check_layernorm()
check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_classifier_given_embed_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
check_operations()
check_layer()
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_2p5d():
spawn(check_layer_and_operation, 4)
if __name__ == '__main__':
test_2p5d()

View File

@@ -0,0 +1,875 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import time
import torch
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context
from colossalai.legacy.nn import (
Classifier3D,
CrossEntropyLoss3D,
Embedding3D,
LayerNorm3D,
Linear3D,
PatchEmbedding3D,
VanillaClassifier,
VanillaPatchEmbedding,
VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D,
VocabParallelEmbedding3D,
)
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device, print_rank_0
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Linear3D(INPUT_SIZE, OUTPUT_SIZE, bias=True)
layer = layer.to(device)
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data.transpose(0, 1).contiguous()
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_layernorm():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
norm = LayerNorm3D(INPUT_SIZE, eps=1e-6)
norm = norm.to(device)
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
norm_master = norm_master.to(device)
weight_master = norm_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH)[k]
norm.weight.data.copy_(weight)
bias_master = norm_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[k]
norm.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = norm(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = norm_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
bias_grad = norm_master.weight.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad)))
bias_grad = norm_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Classifier3D(INPUT_SIZE, NUM_CLASSES, bias=True)
layer = layer.to(device)
layer_master = VanillaClassifier(INPUT_SIZE, NUM_CLASSES, bias=True)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
layer.bias.data.copy_(bias_master)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
else:
logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
rank, layer.weight.grad is None))
bias_grad = layer_master.bias.grad
logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelClassifier3D(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(device)
layer_master = VanillaClassifier(INPUT_SIZE, VOCAB_SIZE, bias=True)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
bias = torch.chunk(bias_master, DEPTH)[j]
layer.bias.data.copy_(bias)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[k]
A = torch.chunk(A, DEPTH, dim=0)[j]
A = A.clone()
A.requires_grad = True
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format(
rank, check_equal(A_grad, A.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format(
rank, check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
embed = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(dtype).to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
layer = Classifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(dtype).to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(dtype).to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(embed(A))
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
if j == k:
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
rank, check_equal(B_grad, embed.weight.grad)))
else:
logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
rank, embed.weight.grad is None))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
embed = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(device)
embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
embed_master = embed_master.to(device)
weight_master = embed_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
embed.weight.data.copy_(weight)
layer = VocabParallelClassifier3D(HIDDEN_SIZE, VOCAB_SIZE, weight=embed.weight, bias=False)
layer = layer.to(device)
layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, weight=embed_master.weight, bias=False)
layer_master = layer_master.to(device)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(embed(A))
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[j]
C = torch.chunk(C, DEPTH, dim=0)[k]
logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
grad = torch.chunk(grad, DEPTH, dim=0)[k]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = embed_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
embed.weight.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_patch_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = PatchEmbedding3D(IMG_SIZE, 4, 3, HIDDEN_SIZE)
torch.nn.init.ones_(layer.cls_token)
torch.nn.init.ones_(layer.pos_embed)
layer = layer.to(device)
layer_master = VanillaPatchEmbedding(IMG_SIZE, 4, 3, HIDDEN_SIZE)
torch.nn.init.ones_(layer_master.cls_token)
torch.nn.init.ones_(layer_master.pos_embed)
layer_master = layer_master.to(device)
proj_weight_master = layer_master.weight.data
torch.distributed.broadcast(proj_weight_master, src=0)
proj_weight = torch.chunk(proj_weight_master, DEPTH, dim=0)[k]
layer.weight.data.copy_(proj_weight)
proj_bias_master = layer_master.bias.data
torch.distributed.broadcast(proj_bias_master, src=0)
proj_bias = torch.chunk(proj_bias_master, DEPTH)[k]
layer.bias.data.copy_(proj_bias)
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
A_master = torch.randn(A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
print_rank_0(
'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start), logger)
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
grad_master = grad_master.clone()
C_master.backward(grad_master)
cls_grad_master = layer_master.cls_token.grad
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
pos_grad_master = layer_master.pos_embed.grad
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank,
check_equal(pos_grad, layer.pos_embed.grad)))
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank,
check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank,
check_equal(bias_grad, layer.bias.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = Embedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
fwd_end - fwd_start),
ranks=[0])
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
logger = get_dist_logger()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
layer = VocabParallelEmbedding3D(VOCAB_SIZE, HIDDEN_SIZE)
layer = layer.to(device)
layer_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE)
layer_master = layer_master.to(device)
weight_master = layer_master.weight.data
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[j]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
weight = torch.chunk(weight, DEPTH, dim=-1)[k]
layer.weight.data.copy_(weight)
A_shape = (BATCH_SIZE, SEQ_LENGTH)
A_master = torch.randint(VOCAB_SIZE, A_shape, device=device)
torch.distributed.broadcast(A_master, src=0)
A = A_master.clone()
fwd_start = time.time()
out = layer(A)
torch.cuda.synchronize()
fwd_end = time.time()
logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
ranks=[0])
A_master = A_master.clone()
C_master = layer_master(A_master)
C = torch.chunk(C_master, DEPTH, dim=0)[i]
C = torch.chunk(C, DEPTH, dim=-1)[k]
C = torch.chunk(C, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
grad = grad.clone()
bwd_start = time.time()
out.backward(grad)
torch.cuda.synchronize()
bwd_end = time.time()
logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
grad_master = grad_master.clone()
C_master.backward(grad_master)
B_grad = layer_master.weight.grad
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
check_equal(B_grad,
layer.weight.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
criterion = CrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=0)[j]
out = out.clone()
out.requires_grad = True
fwd_start = time.time()
loss = criterion(out, target_master)
fwd_end = time.time()
logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape),
fwd_end - fwd_start),
ranks=[0])
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time()
loss.backward()
bwd_end = time.time()
logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start
def check_vocab_parallel_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
output_parallel_mode = get_parallel_mode_from_env(OUTPUT_GROUP_3D)
j = global_context.get_local_rank(input_parallel_mode)
i = global_context.get_local_rank(weight_parallel_mode)
k = global_context.get_local_rank(output_parallel_mode)
criterion = VocabParallelCrossEntropyLoss3D()
criterion_master = torch.nn.CrossEntropyLoss()
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
out = torch.chunk(out, DEPTH, dim=-1)[k]
out = torch.chunk(out, DEPTH, dim=0)[j]
out = out.clone()
out.requires_grad = True
fwd_start = time.time()
loss = criterion(out, target_master)
fwd_end = time.time()
logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
ranks=[0])
out_master = out_master.clone()
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
bwd_start = time.time()
loss.backward()
bwd_end = time.time()
logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
loss_master.backward()
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
return fwd_end - fwd_start, bwd_end - bwd_start

View File

@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
DEPTH = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
NUM_CLASSES = 8
NUM_BLOCKS = 2
IMG_SIZE = 16
VOCAB_SIZE = 16
def check_equal(A, B):
eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2)
assert eq, f"\nA = {A}\nB = {B}"
return eq

View File

@@ -0,0 +1,64 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pytest
import torch
from checks_3d.check_layer_3d import (
check_classifier_no_given_weight,
check_embed,
check_layernorm,
check_linear,
check_loss,
check_patch_embed,
check_vocab_parallel_classifier_given_embed_weight,
check_vocab_parallel_classifier_no_given_weight,
check_vocab_parallel_embed,
check_vocab_parallel_loss,
)
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
CONFIG = dict(
parallel=dict(
pipeline=1,
tensor=dict(mode='3d', size=8),
),
seed=42,
)
def check_layer():
check_linear()
check_layernorm()
check_classifier_no_given_weight()
check_vocab_parallel_classifier_no_given_weight()
check_vocab_parallel_classifier_given_embed_weight()
check_embed()
check_patch_embed()
check_vocab_parallel_embed()
check_loss()
check_vocab_parallel_loss()
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = True
check_layer()
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_3d():
spawn(check_layer_and_operation, 8)
if __name__ == '__main__':
test_3d()

View File

@@ -0,0 +1,377 @@
import random
from typing import List
import numpy as np
import pytest
import torch
import colossalai
from colossalai.legacy.nn.parallel.layers import (
CachedEmbeddingBag,
CachedParamMgr,
EvictionStrategy,
ParallelCachedEmbeddingBag,
ParallelCachedEmbeddingBagTablewise,
TablewiseEmbeddingBagConfig,
)
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8
def set_seed(seed):
"""
To achieve reproducible results, it's necessary to fix random seeds
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def synthesize_1d_sparse_feature(
batch_size,
num_embed,
device,
):
indices_in_batch = batch_size * 2
indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
offsets = torch.from_numpy(
np.array([
0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch
])).to(device).long()
return indices, offsets
@pytest.mark.skip
@clear_cache_before_run()
def test_cachemgr():
model = torch.nn.EmbeddingBag(10000, 128)
# 10 chunks, 5 in cuda
mgr = CachedParamMgr(model.weight.detach(), 5)
assert mgr.cuda_row_num == 5
mgr._admit(1)
assert not mgr._chunk_in_cuda(2)
assert mgr._chunk_in_cuda(1)
# print(mgr.cached_chunk_table)
mgr._admit(8)
# now 3 chunk is available
assert mgr.cuda_available_chunk_num == 3
mgr._evict()
assert mgr.cuda_available_chunk_num == 4
mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0))
mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0))
# print(mgr.cached_chunk_table)
# mgr.print_comm_stats()
mgr.flush()
assert mgr.cuda_available_chunk_num == 5
@clear_cache_before_run()
def test_reorder_with_freq():
num_embed = 100
chunk_size = 1
num_chunk = 5
idx_map = torch.randint(10000, size=(num_embed,))
sorted_idx = torch.argsort(idx_map, descending=True).tolist()
chunkid, offset_in_chunk = [], []
for i in range(num_embed):
idx = sorted_idx.index(i)
chunkid.append(idx // chunk_size)
offset_in_chunk.append(idx % chunk_size)
dev = torch.device('cuda')
chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev)
offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)
weight = torch.rand(num_embed, 2)
mgr = CachedParamMgr(weight, num_chunk)
mgr.reorder(idx_map)
indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
mgr_offsets = torch.remainder(indices, chunk_size)
assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
assert torch.allclose(offset_in_chunk, mgr_offsets), \
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
@clear_cache_before_run()
@parameterize('use_LFU', [True, False])
def test_freq_aware_embed(use_LFU: bool):
device = torch.device('cuda', 0)
evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
model = CachedEmbeddingBag(NUM_EMBED,
EMBED_DIM,
mode='mean',
include_last_offset=True,
cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
ids_freq_mapping=None,
evict_strategy=evict_strategy).to(device)
assert model.weight.shape[0] == NUM_EMBED
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
mode='mean',
include_last_offset=True,
freeze=False)
assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
for i in range(5):
indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
res = model(indices, offsets)
ref_res = ref_model(indices, offsets)
assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"
grad = torch.rand_like(res)
# comparing gradient here is nontrivial
res.backward(grad)
ref_res.backward(grad)
optimizer.step()
optimizer.zero_grad()
ref_optimizer.step()
ref_optimizer.zero_grad()
model.cache_weight_mgr.flush()
model_weight = model.weight.detach().to(device)
ref_weight = ref_model.weight.detach()
assert torch.allclose(model_weight, ref_weight), \
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
@clear_cache_before_run()
@parameterize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool):
# minimal test to check behavior
Bag = CachedEmbeddingBag(5,
5,
cache_ratio=3 / 5,
buffer_size=0,
pin_weight=True,
ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
warmup_ratio=1.0,
evict_strategy=EvictionStrategy.LFU)
# print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
offsets = torch.tensor([0], device="cuda:0")
# prepare frequency learning info:
Bag.forward(torch.tensor([2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
# check strategy
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
Bag.forward(torch.tensor([3], device="cuda:0"), offsets) # miss, evict 1
Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit
Bag.forward(torch.tensor([4], device="cuda:0"), offsets) # miss, evict 3
Bag.forward(torch.tensor([2], device="cuda:0"), offsets) # hit
Bag.forward(torch.tensor([0], device="cuda:0"), offsets) # hit
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
"LFU strategy behavior failed"
def gather_tensor(tensor, rank, world_size):
gather_list = []
if rank == 0:
gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
torch.distributed.gather(tensor, gather_list, dst=0)
return gather_list
def run_parallel_freq_aware_embed_tablewise(rank, world_size):
if world_size != 2:
return
device = torch.device('cuda', torch.cuda.current_device())
# initialize weight
# 3 feature tables. idx: 0~5, 6~10, 11~17
weight_tables = torch.rand(18, 5)
weight_table1 = weight_tables[0:6]
weight_table2 = weight_tables[6:11]
weight_table3 = weight_tables[11:18]
embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
embedding_bag_config_list.append(
TablewiseEmbeddingBagConfig(num_embeddings=6,
cuda_row_num=4,
assigned_rank=0,
initial_weight=weight_table1.clone().detach().cpu()))
embedding_bag_config_list.append(
TablewiseEmbeddingBagConfig(num_embeddings=5,
cuda_row_num=4,
assigned_rank=0,
initial_weight=weight_table2.clone().detach().cpu()))
embedding_bag_config_list.append(
TablewiseEmbeddingBagConfig(num_embeddings=7,
cuda_row_num=4,
assigned_rank=1,
initial_weight=weight_table3.clone().detach().cpu()))
if rank == 0:
_weight = torch.cat([weight_table1, weight_table2], 0)
else:
_weight = weight_table3
model = ParallelCachedEmbeddingBagTablewise(
embedding_bag_config_list,
embedding_dim=5,
_weight=_weight,
include_last_offset=True,
cache_ratio=0.5,
buffer_size=0,
evict_strategy=EvictionStrategy.LFU,
)
# explain
'''
batch feature 1 feature 2 feature 3
input0 [1,2,3] [6,7] []
input1 [] [9] [13,15]
input2 [1,5] [6,8] [11]
↑ ↑ ↑
rank 0 rank 0 rank 1
in KJT format
'''
res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
already_split_along_rank=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
if rank == 0:
fake_grad = rand_grad[0:2]
else:
fake_grad = rand_grad[2:]
res.backward(fake_grad)
optimizer.step()
optimizer.zero_grad()
# check correctness
if rank == 0:
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
include_last_offset=True,
freeze=False).to(device)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
ref_res = ref_model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
ref_res.backward(ref_fake_grad)
ref_optimizer.step()
ref_optimizer.zero_grad()
model.cache_weight_mgr.flush()
recover_weight = model.cache_weight_mgr.weight.to(device)
ref_weight = ref_model.weight.detach()[:11]
assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
def run_parallel_freq_aware_embed_columnwise(rank, world_size):
device = torch.device('cuda', torch.cuda.current_device())
num_embed = 100
embed_dim = 16
batch_size = 4
set_seed(4321)
weight = torch.rand(num_embed, embed_dim)
coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)
# initialize the tensor spec for the embedding weight parameter,
# which is an ColoParameter.
coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
model = ParallelCachedEmbeddingBag.from_pretrained(
coloweight,
include_last_offset=True,
freeze=False,
cache_ratio=batch_size * 2 / num_embed,
)
assert model.cache_weight_mgr.weight.device.type == 'cpu'
assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
assert torch.allclose(weight_in_rank,
model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
if rank == 0:
ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
include_last_offset=True,
freeze=False).to(device)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
set_seed(4321)
for i in range(5):
indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
res = model(indices, offsets)
grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
res.backward(grad_in_rank)
optimizer.step()
optimizer.zero_grad()
res_list = gather_tensor(res.detach(), rank, world_size)
if rank == 0:
ref_res = ref_model(indices, offsets)
recover_res = torch.cat(res_list, dim=0)
assert torch.allclose(ref_res, recover_res)
ref_res.backward(grad)
ref_optimizer.step()
ref_optimizer.zero_grad()
model.cache_weight_mgr.flush()
weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
if rank == 0:
recover_weight = torch.cat(weight_list, dim=1)
assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# run_parallel_freq_aware_embed_columnwise(rank, world_size)
run_parallel_freq_aware_embed_tablewise(rank, world_size)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
spawn(run_dist, world_size)
if __name__ == '__main__':
# test_freq_aware_embed(True)
test_parallel_freq_aware_embed(2)
# test_lfu_strategy(False)

View File

@@ -0,0 +1,21 @@
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device
def check_selfattention():
WORLD_SIZE = gpc.get_world_size(ParallelMode.SEQUENCE)
SUB_SEQ_LENGTH = 8
BATCH = 4
HIDDEN_SIZE = 16
layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
layer = layer.to(get_current_device())
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
attention_mask = torch.randint(low=0, high=2,
size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
out = layer(hidden_states, attention_mask)

View File

@@ -0,0 +1,139 @@
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence import RingAV, RingQK
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence')))
def check_ring_qk(rank, world_size):
# params
batch_size = 4
num_heads = 4
seq_length = 32
attention_head_size = 32
sub_seq_length = seq_length // world_size
# create master tensors
q = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
k = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(q, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(k, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_q = q.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
sub_k = k.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
# set autograd attributes
q.requires_grad = True
k.requires_grad = True
q.retain_grad()
k.retain_grad()
sub_q.requires_grad = True
sub_k.requires_grad = True
sub_q.retain_grad()
sub_k.retain_grad()
# compute master attention scores
a = torch.matmul(q, k.transpose(2, 1))
# compute distributed attention scores
ring_qk = RingQK.apply
sub_a = ring_qk(sub_q, sub_k, batch_size, num_heads, sub_seq_length)
# check master and distributed attention scores
sub_master_a = a[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_a, sub_master_a, rtol=1e-5, atol=1e-2)
# run master backward
a.retain_grad()
a.mean().backward()
# run distributed backward
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_a, partial_master_a_grad)
# check master and distributed grads
partial_master_q_grad = q.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_q.grad, partial_master_q_grad, rtol=1e-5, atol=1e-2), \
'attention score cannot match'
def check_ring_av(rank, world_size):
# params
batch_size = 4
num_heads = 4
seq_length = 16
attention_head_size = 32
sub_seq_length = seq_length // world_size
# create master tensors
a = torch.rand(batch_size * num_heads, seq_length, seq_length).cuda()
v = torch.rand(batch_size * num_heads, seq_length, attention_head_size).cuda()
dist.broadcast(a, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
dist.broadcast(v, src=0, group=gpc.get_group(ParallelMode.SEQUENCE))
# create distributed tensors
sub_a = a.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
sub_v = v.clone()[:, rank * sub_seq_length:(rank + 1) * sub_seq_length].contiguous()
# set autograd attributes
a.requires_grad = True
v.requires_grad = True
a.retain_grad()
v.retain_grad()
sub_a.requires_grad = True
sub_v.requires_grad = True
sub_a.retain_grad()
sub_v.retain_grad()
# compute master attention scores
out = torch.matmul(a, v)
# compute distributed attention scores
ring_av = RingAV.apply
sub_out = ring_av(sub_a, sub_v, batch_size, num_heads, attention_head_size, sub_seq_length)
# print(f'master output shape: {out.shape}, partial output shape: {sub_out.shape}')
# check master and distributed output
sub_master_out = out[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_out, sub_master_out, rtol=1e-5, atol=1e-2)
# # run master backward
out.retain_grad()
out.mean().backward()
# # run distributed backward
partial_master_out_grad = out.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
torch.autograd.backward(sub_out, partial_master_out_grad)
# # check master and distributed grads
partial_master_a_grad = a.grad[:, rank * sub_seq_length:(rank + 1) * sub_seq_length]
assert torch.allclose(sub_a.grad, partial_master_a_grad, rtol=1e-5, atol=1e-2), \
'attention output cannot match'
def run_test(rank, world_size, port):
colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port)
# check_ring_qk(rank, world_size)
check_ring_av(rank, world_size)
gpc.destroy()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_sequence():
spawn(run_test, 4)
if __name__ == '__main__':
test_sequence()