[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -30,7 +30,7 @@ def check_linear():
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = Linear2p5D(INPUT_SIZE, OUTPUT_SIZE, dtype=dtype, skip_bias_add=False)
@@ -50,7 +50,7 @@ def check_linear():
W = W.clone()
W.requires_grad = True
B_shape = (OUTPUT_SIZE)
B_shape = OUTPUT_SIZE
B_master = torch.randn(B_shape, dtype=dtype, device=device)
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j]
@@ -60,7 +60,7 @@ def check_linear():
layer.weight = Parameter(W)
layer.bias = Parameter(B)
out = layer(A)
bias = layer.bias
layer.bias
A_master = A_master.clone()
A_master.requires_grad = True
@@ -73,7 +73,7 @@ def check_linear():
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('linear forward: pass')
print_rank_0("linear forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -100,7 +100,7 @@ def check_linear():
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('linear backward: pass')
print_rank_0("linear backward: pass")
def check_layernorm():
@@ -111,7 +111,7 @@ def check_layernorm():
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layernorm = LayerNorm2p5D(INPUT_SIZE, dtype=dtype)
@@ -138,7 +138,7 @@ def check_layernorm():
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('layer norm forward: pass')
print_rank_0("layer norm forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -152,7 +152,7 @@ def check_layernorm():
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=0)[i]
A_grad = torch.chunk(A_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(A_grad, A.grad)
print_rank_0('layer norm backward: pass')
print_rank_0("layer norm backward: pass")
def check_embed():
@@ -160,7 +160,7 @@ def check_embed():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
@@ -184,7 +184,7 @@ def check_embed():
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('embed forward: pass')
print_rank_0("embed forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -200,7 +200,7 @@ def check_embed():
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('embed backward: pass')
print_rank_0("embed backward: pass")
def check_patch_embed():
@@ -208,7 +208,7 @@ def check_patch_embed():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = PatchEmbedding2p5D(IMG_SIZE, 4, 3, HIDDEN_SIZE, dtype=dtype)
torch.nn.init.ones_(layer.cls_token)
@@ -242,7 +242,7 @@ def check_patch_embed():
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('patch embed forward: pass')
print_rank_0("patch embed forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -274,7 +274,7 @@ def check_patch_embed():
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[j]
bias_grad = torch.chunk(bias_grad, TESSERACT_DIM)[i]
check_equal(bias_grad, layer.bias.grad)
print_rank_0('patch embed backward: pass')
print_rank_0("patch embed backward: pass")
def check_vocab_parallel_embed():
@@ -282,7 +282,7 @@ def check_vocab_parallel_embed():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
@@ -306,7 +306,7 @@ def check_vocab_parallel_embed():
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel embed forward: pass')
print_rank_0("vocab parallel embed forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -322,7 +322,7 @@ def check_vocab_parallel_embed():
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
check_equal(B_grad, embed.weight.grad)
print_rank_0('vocab parallel embed backward: pass')
print_rank_0("vocab parallel embed backward: pass")
def check_classifier_no_given_weight():
@@ -374,7 +374,7 @@ def check_classifier_no_given_weight():
# C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('classifier (no given weight) forward: pass')
print_rank_0("classifier (no given weight) forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -401,7 +401,7 @@ def check_classifier_no_given_weight():
# if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('classifier (no given weight) backward: pass')
print_rank_0("classifier (no given weight) backward: pass")
def check_vocab_parallel_classifier_no_given_weight():
@@ -409,7 +409,7 @@ def check_vocab_parallel_classifier_no_given_weight():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
layer = VocabParallelClassifier2p5D(HIDDEN_SIZE, VOCAB_SIZE, bias=True)
layer = layer.to(dtype).to(device)
@@ -442,7 +442,7 @@ def check_vocab_parallel_classifier_no_given_weight():
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (no given weight) forward: pass')
print_rank_0("vocab parallel classifier (no given weight) forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -468,7 +468,7 @@ def check_vocab_parallel_classifier_no_given_weight():
B_grad = torch.chunk(B_grad, TESSERACT_DIM)[j]
if i == 0:
check_equal(B_grad, layer.bias.grad)
print_rank_0('vocab parallel classifier (no given weight) backward: pass')
print_rank_0("vocab parallel classifier (no given weight) backward: pass")
def check_classifier_given_embed_weight():
@@ -476,7 +476,7 @@ def check_classifier_given_embed_weight():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = Embedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
@@ -504,7 +504,7 @@ def check_classifier_given_embed_weight():
C_master = layer_master(embed_master(A_master))
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
check_equal(out, C)
print_rank_0('classifier (given embed weight) forward: pass')
print_rank_0("classifier (given embed weight) forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -520,7 +520,7 @@ def check_classifier_given_embed_weight():
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('classifier (given embed weight) backward: pass')
print_rank_0("classifier (given embed weight) backward: pass")
def check_vocab_parallel_classifier_given_embed_weight():
@@ -528,7 +528,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
embed = VocabParallelEmbedding2p5D(VOCAB_SIZE, HIDDEN_SIZE)
embed = embed.to(dtype).to(device)
@@ -557,7 +557,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
C = torch.chunk(C_master, TESSERACT_DIM, dim=0)[i]
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
check_equal(out, C)
print_rank_0('vocab parallel classifier (given embed weight) forward: pass')
print_rank_0("vocab parallel classifier (given embed weight) forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -574,15 +574,15 @@ def check_vocab_parallel_classifier_given_embed_weight():
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=-1)[j]
W_grad = torch.chunk(W_grad, TESSERACT_DIM, dim=0)[i]
check_equal(W_grad, embed.weight.grad)
print_rank_0('vocab parallel classifier (given embed weight) backward: pass')
print_rank_0("vocab parallel classifier (given embed weight) backward: pass")
def check_loss():
device = get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
criterion = CrossEntropyLoss2p5D()
criterion_master = torch.nn.CrossEntropyLoss()
@@ -601,7 +601,7 @@ def check_loss():
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('cross entropy loss forward: pass')
print_rank_0("cross entropy loss forward: pass")
loss.backward()
loss_master.backward()
@@ -609,7 +609,7 @@ def check_loss():
out_grad = out_master.grad
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
check_equal(out_grad, out.grad)
print_rank_0('cross entropy loss backward: pass')
print_rank_0("cross entropy loss backward: pass")
def check_vocab_parallel_loss():
@@ -617,7 +617,7 @@ def check_vocab_parallel_loss():
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
criterion = VocabParallelCrossEntropyLoss2p5D()
criterion_master = torch.nn.CrossEntropyLoss()
@@ -637,7 +637,7 @@ def check_vocab_parallel_loss():
out_master.requires_grad = True
loss_master = criterion_master(out_master, target_master)
check_equal(loss, loss_master)
print_rank_0('vocab parallel cross entropy loss forward: pass')
print_rank_0("vocab parallel cross entropy loss forward: pass")
loss.backward()
loss_master.backward()
@@ -646,7 +646,7 @@ def check_vocab_parallel_loss():
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=0)[i]
out_grad = torch.chunk(out_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(out_grad, out.grad)
print_rank_0('vocab parallel cross entropy loss backward: pass')
print_rank_0("vocab parallel cross entropy loss backward: pass")
# def check_attention():

View File

@@ -11,10 +11,12 @@ from .common import *
def check_AB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
pipeline_parallel_rank = (
0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
)
pipeline_parallel_size = (
1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE)
)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
@@ -39,11 +41,23 @@ def check_AB():
B.requires_grad = True
out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM)
out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank,
pipeline_parallel_size, tensor_parallel_size)
out = Matmul_AB_2p5D.apply(
A,
B,
TESSERACT_DIM,
out_shape,
i,
j,
k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size,
)
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
(BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
B_master = B_master.clone()
@@ -53,7 +67,7 @@ def check_AB():
C = torch.chunk(C, TESSERACT_DIM, dim=-1)[j]
# check forward correctness
check_equal(out, C)
print_rank_0('AB forward: pass')
print_rank_0("AB forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -75,15 +89,17 @@ def check_AB():
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
# check backward correctness
check_equal(B_grad, B.grad)
print_rank_0('AB backward: pass')
print_rank_0("AB backward: pass")
def check_ABT():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
pipeline_parallel_rank = (
0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
)
pipeline_parallel_size = (
1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE)
)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
@@ -109,12 +125,23 @@ def check_ABT():
B = B.clone()
B.requires_grad = True
out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM,
(BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k,
ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank,
pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
out = Matmul_ABT_2p5D.apply(
C,
B,
TESSERACT_DIM,
(BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM),
i,
j,
k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size,
)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
(BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
C_master = C_master.clone()
C_master.requires_grad = True
B_master = B_master.clone()
@@ -123,7 +150,7 @@ def check_ABT():
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
check_equal(out, A)
print_rank_0('ABT forward: pass')
print_rank_0("ABT forward: pass")
grad_shape = A_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -144,15 +171,17 @@ def check_ABT():
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=0)[i]
B_grad = torch.chunk(B_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(B_grad, B.grad)
print_rank_0('ABT backward: pass')
print_rank_0("ABT backward: pass")
def check_ATB():
data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
ParallelMode.PIPELINE)
pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
pipeline_parallel_rank = (
0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
)
pipeline_parallel_size = (
1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(ParallelMode.PIPELINE)
)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
@@ -178,22 +207,34 @@ def check_ATB():
C = C.clone()
C.requires_grad = True
out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
out = Matmul_ATB_2p5D.apply(
A,
C,
TESSERACT_DIM,
(HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM),
i,
j,
k,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size,
)
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
(HIDDEN_SIZE, 4 * HIDDEN_SIZE)
A_master = A_master.clone()
A_master.requires_grad = True
C_master = C_master.clone()
C_master.requires_grad = True
B_master = torch.matmul(
A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])
)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
check_equal(out, B)
print_rank_0('ATB forward: pass')
print_rank_0("ATB forward: pass")
grad_shape = B_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
@@ -213,4 +254,4 @@ def check_ATB():
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=0)[i]
C_grad = torch.chunk(C_grad, TESSERACT_DIM, dim=-1)[j]
check_equal(C_grad, C.grad)
print_rank_0('ATB backward: pass')
print_rank_0("ATB backward: pass")

View File

@@ -8,10 +8,12 @@ from colossalai.legacy.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2.5d', depth=1),
),)
CONFIG = dict(
parallel=dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode="2.5d", depth=1),
),
)
def check_operations():
@@ -36,7 +38,7 @@ def check_layer():
def check_layer_and_operation(rank, world_size, port):
disable_existing_loggers()
launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
@@ -53,5 +55,5 @@ def test_2p5d():
spawn(check_layer_and_operation, 4)
if __name__ == '__main__':
if __name__ == "__main__":
test_2p5d()