Migrated project

This commit is contained in:
zbian
2021-10-28 18:21:23 +02:00
parent 2ebaefc542
commit 404ecbdcc6
409 changed files with 35853 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
import torch
TESSERACT_DIM = 2
TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) == True

View File

@@ -0,0 +1,3 @@
#!/bin/bash
python -m torch.distributed.launch test_2p5d.py --nproc_per_node 8 --host $HOST --port 29516 --world_size 8

View File

@@ -0,0 +1,41 @@
import pytest
from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=8, mode='2.5d', depth=2),
),
)
def check_operations():
check_AB()
check_ABT()
check_ATB()
def check_layer():
check_linear()
check_layernorm()
check_attention()
check_mlp()
check_transformerlayer()
@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d():
init_dist(config=CONFIG)
gpc.set_seed()
check_layer()
check_operations()
gpc.destroy()
if __name__ == '__main__':
test_2p5d()

View File

@@ -0,0 +1,265 @@
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (Linear2p5D, LayerNorm2p5D, TransformerSelfAttention2p5D, TransformerMLP2p5D,
TransformerLayer2p5D)
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import *
def check_linear():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = Linear2p5D(
INPUT_SIZE,
OUTPUT_SIZE,
dtype=dtype,
skip_bias_add=False)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
W_shape = (INPUT_SIZE, OUTPUT_SIZE)
W_master = torch.randn(W_shape, dtype=dtype, device=device)
torch.distributed.broadcast(W_master, src=0)
W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i]
W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j]
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
B = B.clone()
B.requires_grad = True
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
bias = layer.bias
A_master = A_master.clone()
A_master.requires_grad = True
W_master = W_master.clone()
W_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, W_master) + B_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('linear forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
W_grad = W_master.grad
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(W_grad, layer.weight.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
def check_layernorm():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layernorm = LayerNorm2p5D(
INPUT_SIZE,
dtype=dtype)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layernorm(A)
A_master = A_master.clone()
A_master.requires_grad = True
E_master = torch.sum(A_master, dim=-1, keepdim=True)
E_master /= INPUT_SIZE
V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
V_master /= INPUT_SIZE
V_master = V_master - E_master * E_master
V_master = 1.0 / torch.sqrt(V_master + EPS)
C_master = (A_master - E_master) * V_master
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
def check_attention():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerSelfAttention2p5D(
HIDDEN_SIZE, NUM_ATTENTION_HEADS,
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('self attention forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('self attention backward: pass')
def check_mlp():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerMLP2p5D(
HIDDEN_SIZE,
mlp_ratio=1,
dropout_prob=0.5,
act_func='gelu',
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
out = layer(A)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('mlp forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('mlp backward: pass')
def check_transformerlayer():
device = get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
NUM_ATTENTION_HEADS = 2
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = TransformerLayer2p5D(
HIDDEN_SIZE,
NUM_ATTENTION_HEADS,
act_func='gelu',
attention_dropout_prob=0.5,
hidden_dropout_prob=0.5,
dtype=dtype,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
out = layer(A, attention_mask)
assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
print_rank_0('transformerlayer forward: pass')
grad_shape = out.shape
grad = torch.randn(grad_shape, dtype=dtype, device=device)
out.backward(grad)
assert A.grad.shape == A.shape
print_rank_0('transformerlayer backward: pass')

View File

@@ -0,0 +1,239 @@
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \
Matmul_ATB_2p5D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import *
def check_AB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
out = Matmul_AB_2p5D.apply(
A, B,
TESSERACT_DIM, TESSERACT_DEP, out_shape,
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
C_master = torch.matmul(A_master, B_master)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
# check forward correctness
check_equal(out, C)
print_rank_0('AB forward: pass')
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
C_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(A_grad, A.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(B_grad, B.grad)
print_rank_0('AB backward: pass')
def check_ABT():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
C = C.clone()
C.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
B = B.clone()
B.requires_grad = True
out = Matmul_ABT_2p5D.apply(
C, B,
TESSERACT_DIM, TESSERACT_DEP, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
C_master = C_master.clone()
C_master.requires_grad = True
B_master = B_master.clone()
B_master.requires_grad = True
A_master = torch.matmul(C_master, B_master.transpose(0, 1))
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
check_equal(out, A)
print_rank_0('ABT forward: pass')
grad_shape = A_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
# backward
out.backward(grad)
A_master.backward(grad_master)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
B_grad = B_master.grad
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(B_grad, B.grad)
print_rank_0('ABT backward: pass')
def check_ATB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
A = A.clone()
A.requires_grad = True
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
C_master = torch.randn(C_shape, dtype=dtype, device=device)
torch.distributed.broadcast(C_master, src=0)
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
C = C.clone()
C.requires_grad = True
out = Matmul_ATB_2p5D.apply(
A, C,
TESSERACT_DIM, TESSERACT_DEP, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
i, j, k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_DEP,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size)
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = C_master.clone()
C_master.requires_grad = True
B_master = torch.matmul(
A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
C_master.view(-1, C_master.shape[-1]))
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
check_equal(out, B)
print_rank_0('ATB forward: pass')
grad_shape = B_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
out.backward(grad)
B_master.backward(grad_master)
A_grad = A_master.grad
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
C_grad = C_master.grad
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
print_rank_0('ATB backward: pass')