Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 22:52:25 +00:00)

Commit: Migrated project

tests/test_layers/test_2p5d/common.py (new file, 11 lines)
@@ -0,0 +1,11 @@
import torch

TESSERACT_DIM = 2
TESSERACT_DEP = 2
BATCH_SIZE = 8
SEQ_LENGTH = 8
HIDDEN_SIZE = 8


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-5, atol=1e-2)
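
Note: check_equal uses loose tolerances (rtol=1e-5, atol=1e-2) because the distributed kernels accumulate results in a different order than the single-device reference. A tiny illustration of what the assertion accepts, using made-up values:

import torch

a = torch.ones(4)
assert torch.allclose(a, a + 5e-3, rtol=1e-5, atol=1e-2)      # within atol: accepted
assert not torch.allclose(a, a + 5e-2, rtol=1e-5, atol=1e-2)  # beyond atol: rejected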

tests/test_layers/test_2p5d/test.sh (new file, 3 lines)
@@ -0,0 +1,3 @@
#!/bin/bash

python -m torch.distributed.launch --nproc_per_node 8 test_2p5d.py --host $HOST --port 29516 --world_size 8

tests/test_layers/test_2p5d/test_2p5d.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import pytest

from colossalai.core import global_context as gpc
from colossalai.initialize import init_dist
from test_layer import check_linear, check_layernorm, check_attention, check_mlp, check_transformerlayer
from test_operation import check_AB, check_ABT, check_ATB

CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=8, mode='2.5d', depth=2),
    ),
)


def check_operations():
    check_AB()
    check_ABT()
    check_ATB()


def check_layer():
    check_linear()
    check_layernorm()
    check_attention()
    check_mlp()
    check_transformerlayer()


@pytest.mark.dist
@pytest.mark.skip("This test should be invoked by test.sh in the same folder as it runs on multiple gpus")
def test_2p5d():
    init_dist(config=CONFIG)
    gpc.set_seed()
    check_layer()
    check_operations()
    gpc.destroy()


if __name__ == '__main__':
    test_2p5d()
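
Note: the CONFIG above requests a tensor-parallel group of size 8 in '2.5d' mode with depth 2, i.e. two stacked 2x2 grids, which matches the 8 processes launched by test.sh. Below is one possible way to factor a flat rank into (depth, row, column) coordinates for such a 2x2x2 arrangement; it is illustrative only, since the real mapping is decided by colossalai's 2.5D process-group initializer and may differ.

# Illustrative rank-to-coordinate factorisation for a 2 x 2 x 2 "tesseract".
TESSERACT_DIM = 2  # side of each 2D grid
TESSERACT_DEP = 2  # number of stacked grids


def tesseract_coords(rank):
    dep = rank // (TESSERACT_DIM * TESSERACT_DIM)
    row = (rank % (TESSERACT_DIM * TESSERACT_DIM)) // TESSERACT_DIM
    col = rank % TESSERACT_DIM
    return dep, row, col


coords = [tesseract_coords(r) for r in range(TESSERACT_DIM * TESSERACT_DIM * TESSERACT_DEP)]
assert sorted(coords) == [(d, r, c) for d in range(TESSERACT_DEP)
                          for r in range(TESSERACT_DIM) for c in range(TESSERACT_DIM)]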

tests/test_layers/test_2p5d/test_layer.py (new file, 265 lines)
@@ -0,0 +1,265 @@
from torch.nn import Parameter

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import (Linear2p5D, LayerNorm2p5D, TransformerSelfAttention2p5D, TransformerMLP2p5D,
                           TransformerLayer2p5D)
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import *


def check_linear():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    layer = Linear2p5D(
        INPUT_SIZE,
        OUTPUT_SIZE,
        dtype=dtype,
        skip_bias_add=False)

    # create identical master tensors on every rank, then take this rank's shard
    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(W_master, src=0)
    W = torch.chunk(W_master, TESSERACT_DIM, dim=0)[i]
    W = torch.chunk(W, TESSERACT_DIM, dim=-1)[j]
    W = W.clone()
    W.requires_grad = True

    B_shape = (OUTPUT_SIZE, )
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
    B = B.clone()
    B.requires_grad = True

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    out = layer(A)
    bias = layer.bias

    # reference result on the unsharded tensors
    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master) + B_master
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('linear forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
    W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(W_grad, layer.weight.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[j]
    if i == 0:
        check_equal(B_grad, layer.bias.grad)

    print_rank_0('linear backward: pass')


def check_layernorm():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    EPS = 1e-12

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    layernorm = LayerNorm2p5D(
        INPUT_SIZE,
        dtype=dtype)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    out = layernorm(A)

    A_master = A_master.clone()
    A_master.requires_grad = True
    E_master = torch.sum(A_master, dim=-1, keepdim=True)
    E_master /= INPUT_SIZE
    V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
    V_master /= INPUT_SIZE
    V_master = V_master - E_master * E_master
    V_master = 1.0 / torch.sqrt(V_master + EPS)
    C_master = (A_master - E_master) * V_master
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('layer norm forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(A_grad, A.grad)
    print_rank_0('layer norm backward: pass')


def check_attention():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    NUM_ATTENTION_HEADS = 2

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    layer = TransformerSelfAttention2p5D(
        HIDDEN_SIZE, NUM_ATTENTION_HEADS,
        attention_dropout_prob=0.5,
        hidden_dropout_prob=0.5,
        dtype=dtype,
    )

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
    attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

    out = layer(A, attention_mask)
    assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
    print_rank_0('self attention forward: pass')

    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('self attention backward: pass')


def check_mlp():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    layer = TransformerMLP2p5D(
        HIDDEN_SIZE,
        mlp_ratio=1,
        dropout_prob=0.5,
        act_func='gelu',
        dtype=dtype,
    )

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    out = layer(A)
    assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
    print_rank_0('mlp forward: pass')

    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('mlp backward: pass')


def check_transformerlayer():
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    NUM_ATTENTION_HEADS = 2

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    layer = TransformerLayer2p5D(
        HIDDEN_SIZE,
        NUM_ATTENTION_HEADS,
        act_func='gelu',
        attention_dropout_prob=0.5,
        hidden_dropout_prob=0.5,
        dtype=dtype,
    )

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    mask_shape = (BATCH_SIZE // TESSERACT_DIM, NUM_ATTENTION_HEADS // TESSERACT_DIM, SEQ_LENGTH, SEQ_LENGTH)
    attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

    out = layer(A, attention_mask)
    assert out.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, INPUT_SIZE // TESSERACT_DIM)
    print_rank_0('transformerlayer forward: pass')

    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('transformerlayer backward: pass')
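
Note: for reference, the comparison pattern used by check_linear above, condensed into a single-process sketch (constants from common.py, arbitrary example coordinates; no communication is involved, so this is not a substitute for the distributed test):

import torch

TESSERACT_DIM, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE = 2, 8, 8, 8
i, j = 1, 0  # example block coordinates: i selects the dim-0 chunk, j the last-dim chunk

A_master = torch.randn(BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
W_master = torch.randn(HIDDEN_SIZE, 2 * HIDDEN_SIZE)
B_master = torch.randn(2 * HIDDEN_SIZE)

# reference result on the unsharded tensors
C_master = torch.matmul(A_master, W_master) + B_master

# the output shard an (i, j) rank is expected to produce
C_local = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C_local = torch.chunk(C_local, TESSERACT_DIM, dim=-1)[j]
assert C_local.shape == (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 2 * HIDDEN_SIZE // TESSERACT_DIM)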

tests/test_layers/test_2p5d/test_operation.py (new file, 239 lines)
@@ -0,0 +1,239 @@
import torch

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, \
    Matmul_ATB_2p5D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import *


def check_AB():
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
    B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
    out = Matmul_AB_2p5D.apply(
        A, B,
        TESSERACT_DIM, TESSERACT_DEP, out_shape,
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
        ParallelMode.PARALLEL_2P5D_DEP,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size)

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, B_master)
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
    # check forward correctness
    check_equal(out, C)
    print_rank_0('AB forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]

    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
    # check backward correctness
    check_equal(A_grad, A.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
    # check backward correctness
    check_equal(B_grad, B.grad)
    print_rank_0('AB backward: pass')


def check_ABT():
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    device = get_current_device()

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
    B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out = Matmul_ABT_2p5D.apply(
        C, B,
        TESSERACT_DIM, TESSERACT_DEP, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
        ParallelMode.PARALLEL_2P5D_DEP,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    A_master = torch.matmul(C_master, B_master.transpose(0, 1))
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    check_equal(out, A)
    print_rank_0('ABT forward: pass')

    grad_shape = A_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]

    # backward
    out.backward(grad)

    A_master.backward(grad_master)
    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(C_grad, C.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
    B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(B_grad, B.grad)
    print_rank_0('ABT backward: pass')


def check_ATB():
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    device = get_current_device()
    dtype = torch.float

    i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
    A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
    C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    out = Matmul_ATB_2p5D.apply(
        A, C,
        TESSERACT_DIM, TESSERACT_DEP, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
        i, j, k,
        ParallelMode.PARALLEL_2P5D_ROW,
        ParallelMode.PARALLEL_2P5D_COL,
        ParallelMode.PARALLEL_2P5D_DEP,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size)

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = torch.matmul(
        A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
        C_master.view(-1, C_master.shape[-1]))
    B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
    B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
    check_equal(out, B)
    print_rank_0('ATB forward: pass')

    grad_shape = B_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
    grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]

    out.backward(grad)

    B_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
    A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(A_grad, A.grad)

    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
    C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
    check_equal(C_grad, C.grad)
    print_rank_0('ATB backward: pass')
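
Note: the identity these operation tests rely on is plain block matrix multiplication, C[i][j] = sum_k A[i][k] @ B[k][j], which the 2.5D AB kernel distributes across the process grid. A serial sketch of that identity (illustration only, not the communication schedule of Matmul_AB_2p5D):

import torch

DIM = 2
A = torch.randn(8, 8)
B = torch.randn(8, 16)

# split A and B into a DIM x DIM grid of blocks
A_blocks = [list(torch.chunk(row, DIM, dim=1)) for row in torch.chunk(A, DIM, dim=0)]
B_blocks = [list(torch.chunk(row, DIM, dim=1)) for row in torch.chunk(B, DIM, dim=0)]

# accumulate each output block from the matching row of A-blocks and column of B-blocks
C_blocks = [[sum(A_blocks[i][k] @ B_blocks[k][j] for k in range(DIM))
             for j in range(DIM)] for i in range(DIM)]
C = torch.cat([torch.cat(row, dim=1) for row in C_blocks], dim=0)

assert torch.allclose(C, A @ B, rtol=1e-5, atol=1e-5)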