Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 14:12:02 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
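The updated pre-commit configuration itself is not part of the hunks shown here; this diff only contains files rewritten by running the hooks. As a rough sketch only (the hook revisions and the CUDA exclude pattern below are assumptions, not taken from this commit), a configuration that formats Python with black and skips CUDA sources for clang-format could look like:

repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1                 # assumed revision, not from this commit
    hooks:
      - id: black
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6                # assumed revision, not from this commit
    hooks:
      - id: clang-format
        exclude: \.cu$          # "ignore cuda for clang-format" read as excluding .cu files

Running `pre-commit run --all-files` against such a setup produces the mechanical quote-style and line-wrapping changes seen in the hunks below.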
@@ -73,14 +73,15 @@ def check_linear():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'linear forward: {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
+        "linear forward: {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger
+    )
     A_master = A_master.clone()
     A_master.requires_grad = True
     C_master = layer_master(A_master)
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[j]
     C = torch.chunk(C, DEPTH, dim=0)[k]
-    logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=get_current_device())
@@ -93,24 +94,24 @@ def check_linear():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('linear backward: {:.3f} s'.format(bwd_end - bwd_start), logger)
+    print_rank_0("linear backward: {:.3f} s".format(bwd_end - bwd_start), logger)
 
     C_master.backward(grad_master)
     A_grad = A_master.grad
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
     A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
-    logger.info('Rank {} linear backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
+    logger.info("Rank {} linear backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad)))
 
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
-    logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
+    logger.info("Rank {} linear backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)))
 
     bias_grad = layer_master.bias.grad
     bias_grad = torch.chunk(bias_grad, DEPTH)[j]
-    logger.info('Rank {} linear backward (bias_grad): {}'.format(rank, check_equal(bias_grad, layer.bias.grad)))
+    logger.info("Rank {} linear backward (bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad)))
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -157,8 +158,11 @@ def check_layernorm():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
-                                                                    fwd_end - fwd_start), logger)
+        "layer norm forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        logger,
+    )
 
     A_master = A_master.clone()
     A_master.requires_grad = True
@@ -166,7 +170,7 @@ def check_layernorm():
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[k]
     C = torch.chunk(C, DEPTH, dim=0)[j]
-    logger.info('Rank {} layernorm forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} layernorm forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=device)
@@ -179,22 +183,22 @@ def check_layernorm():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
+    print_rank_0("layer norm backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger)
 
     C_master.backward(grad_master)
     A_grad = A_master.grad
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
     A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
-    logger.info('Rank {} layernorm backward (input_grad): {}'.format(rank, check_equal(A_grad, A.grad)))
+    logger.info("Rank {} layernorm backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad)))
 
     bias_grad = norm_master.weight.grad
     bias_grad = torch.chunk(bias_grad, DEPTH)[k]
-    logger.info('Rank {} layernorm backward (weight_grad): {}'.format(rank, check_equal(bias_grad, norm.weight.grad)))
+    logger.info("Rank {} layernorm backward (weight_grad): {}".format(rank, check_equal(bias_grad, norm.weight.grad)))
 
     bias_grad = norm_master.bias.grad
     bias_grad = torch.chunk(bias_grad, DEPTH)[k]
-    logger.info('Rank {} layernorm backward (bias_grad): {}'.format(rank, check_equal(bias_grad, norm.bias.grad)))
+    logger.info("Rank {} layernorm backward (bias_grad): {}".format(rank, check_equal(bias_grad, norm.bias.grad)))
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -241,14 +245,17 @@ def check_classifier_no_given_weight():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
-            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
+        "classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        logger,
+    )
     A_master = A_master.clone()
     A_master.requires_grad = True
     C_master = layer_master(A_master)
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=0)[j]
-    logger.info('Rank {} classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=get_current_device())
@@ -261,7 +268,7 @@ def check_classifier_no_given_weight():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
+    print_rank_0("classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger)
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
@@ -269,21 +276,29 @@ def check_classifier_no_given_weight():
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
     A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
-    logger.info('Rank {} classifier (no given weight) backward (input_grad): {}'.format(
-        rank, check_equal(A_grad, A.grad)))
+    logger.info(
+        "Rank {} classifier (no given weight) backward (input_grad): {}".format(rank, check_equal(A_grad, A.grad))
+    )
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     if j == k:
-        logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
-            rank, check_equal(B_grad, layer.weight.grad)))
+        logger.info(
+            "Rank {} classifier (no given weight) backward (weight_grad): {}".format(
+                rank, check_equal(B_grad, layer.weight.grad)
+            )
+        )
     else:
-        logger.info('Rank {} classifier (no given weight) backward (weight_grad): {}'.format(
-            rank, layer.weight.grad is None))
+        logger.info(
+            "Rank {} classifier (no given weight) backward (weight_grad): {}".format(rank, layer.weight.grad is None)
+        )
 
     bias_grad = layer_master.bias.grad
-    logger.info('Rank {} classifier (no given weight) backward (bias_grad): {}'.format(
-        rank, check_equal(bias_grad, layer.bias.grad)))
+    logger.info(
+        "Rank {} classifier (no given weight) backward (bias_grad): {}".format(
+            rank, check_equal(bias_grad, layer.bias.grad)
+        )
+    )
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -333,15 +348,18 @@ def check_vocab_parallel_classifier_no_given_weight():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
-            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
+        "vocab parallel classifier (no given weight) forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        logger,
+    )
     A_master = A_master.clone()
     A_master.requires_grad = True
     C_master = layer_master(A_master)
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[j]
     C = torch.chunk(C, DEPTH, dim=0)[k]
-    logger.info('Rank {} vocab parallel classifier (no given weight) forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} vocab parallel classifier (no given weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=device)
@@ -355,8 +373,9 @@ def check_vocab_parallel_classifier_no_given_weight():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('vocab parallel classifier (no given weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
-                 logger)
+    print_rank_0(
+        "vocab parallel classifier (no given weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger
+    )
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
@@ -364,20 +383,29 @@ def check_vocab_parallel_classifier_no_given_weight():
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
     A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
     A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
-    logger.info('Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}'.format(
-        rank, check_equal(A_grad, A.grad)))
+    logger.info(
+        "Rank {} vocab parallel classifier (no given weight) backward (input_grad): {}".format(
+            rank, check_equal(A_grad, A.grad)
+        )
+    )
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    logger.info('Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}'.format(
-        rank, check_equal(B_grad, layer.weight.grad)))
+    logger.info(
+        "Rank {} vocab parallel classifier (no given weight) backward (weight_grad): {}".format(
+            rank, check_equal(B_grad, layer.weight.grad)
+        )
+    )
 
     bias_grad = layer_master.bias.grad
     bias_grad = torch.chunk(bias_grad, DEPTH)[j]
-    logger.info('Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}'.format(
-        rank, check_equal(bias_grad, layer.bias.grad)))
+    logger.info(
+        "Rank {} vocab parallel classifier (no given weight) backward (bias_grad): {}".format(
+            rank, check_equal(bias_grad, layer.bias.grad)
+        )
+    )
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -423,13 +451,16 @@ def check_classifier_given_embed_weight():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
-            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
+        "classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        logger,
+    )
     A_master = A_master.clone()
     C_master = layer_master(embed_master(A_master))
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=0)[j]
-    logger.info('Rank {} classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
@@ -442,7 +473,7 @@ def check_classifier_given_embed_weight():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
+    print_rank_0("classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger)
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
@@ -450,11 +481,15 @@ def check_classifier_given_embed_weight():
     B_grad = embed_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
     if j == k:
-        logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
-            rank, check_equal(B_grad, embed.weight.grad)))
+        logger.info(
+            "Rank {} classifier (given embed weight) backward (weight_grad): {}".format(
+                rank, check_equal(B_grad, embed.weight.grad)
+            )
+        )
     else:
-        logger.info('Rank {} classifier (given embed weight) backward (weight_grad): {}'.format(
-            rank, embed.weight.grad is None))
+        logger.info(
+            "Rank {} classifier (given embed weight) backward (weight_grad): {}".format(rank, embed.weight.grad is None)
+        )
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -501,14 +536,17 @@ def check_vocab_parallel_classifier_given_embed_weight():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s'.format(
-            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
+        "vocab parallel classifier (given embed weight) forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        logger,
+    )
     A_master = A_master.clone()
     C_master = layer_master(embed_master(A_master))
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[j]
     C = torch.chunk(C, DEPTH, dim=0)[k]
-    logger.info('Rank {} vocab parallel classifier (given embed weight) forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} vocab parallel classifier (given embed weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=device)
@@ -522,8 +560,9 @@ def check_vocab_parallel_classifier_given_embed_weight():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('vocab parallel classifier (given embed weight) backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
-                 logger)
+    print_rank_0(
+        "vocab parallel classifier (given embed weight) backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger
+    )
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
@@ -532,9 +571,9 @@ def check_vocab_parallel_classifier_given_embed_weight():
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
-                                                                                 check_equal(B_grad,
-                                                                                             embed.weight.grad)))
+    logger.info(
+        "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, embed.weight.grad))
+    )
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -543,7 +582,7 @@ def check_patch_embed():
     rank = torch.distributed.get_rank()
     device = get_current_device()
     logger = get_dist_logger()
-    dtype = torch.float32
+    torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -582,15 +621,18 @@ def check_patch_embed():
     torch.cuda.synchronize()
     fwd_end = time.time()
     print_rank_0(
-        'patch embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
-                                                                     fwd_end - fwd_start), logger)
+        "patch embed forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        logger,
+    )
 
     A_master = A_master.clone()
     C_master = layer_master(A_master)
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[k]
     C = torch.chunk(C, DEPTH, dim=0)[j]
-    logger.info('Rank {} patch embed forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} patch embed forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=device)
@@ -604,29 +646,32 @@ def check_patch_embed():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    print_rank_0('patch embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), logger)
+    print_rank_0("patch embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), logger)
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
 
     cls_grad_master = layer_master.cls_token.grad
     cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
-    logger.info('Rank {} patch embed backward (cls_grad): {}'.format(rank, check_equal(cls_grad, layer.cls_token.grad)))
+    logger.info("Rank {} patch embed backward (cls_grad): {}".format(rank, check_equal(cls_grad, layer.cls_token.grad)))
 
     pos_grad_master = layer_master.pos_embed.grad
     pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
-    logger.info('Rank {} patch embed backward (pos_embed_grad): {}'.format(rank,
-                                                                           check_equal(pos_grad, layer.pos_embed.grad)))
+    logger.info(
+        "Rank {} patch embed backward (pos_embed_grad): {}".format(rank, check_equal(pos_grad, layer.pos_embed.grad))
+    )
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
-    logger.info('Rank {} patch embed backward (proj_weight_grad): {}'.format(rank,
-                                                                             check_equal(B_grad, layer.weight.grad)))
+    logger.info(
+        "Rank {} patch embed backward (proj_weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))
+    )
 
     bias_grad = layer_master.bias.grad
     bias_grad = torch.chunk(bias_grad, DEPTH)[k]
-    logger.info('Rank {} patch embed backward (proj_bias_grad): {}'.format(rank,
-                                                                           check_equal(bias_grad, layer.bias.grad)))
+    logger.info(
+        "Rank {} patch embed backward (proj_bias_grad): {}".format(rank, check_equal(bias_grad, layer.bias.grad))
+    )
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -635,7 +680,7 @@ def check_embed():
     rank = torch.distributed.get_rank()
     device = get_current_device()
     logger = get_dist_logger()
-    dtype = torch.float32
+    torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -664,16 +709,17 @@ def check_embed():
     out = layer(A)
     torch.cuda.synchronize()
     fwd_end = time.time()
-    logger.info('embed forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(A.shape), tuple(out.shape),
-                                                                       fwd_end - fwd_start),
-                ranks=[0])
+    logger.info(
+        "embed forward: pass | {0} --> {1} | {2:.3f} s".format(tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
+        ranks=[0],
+    )
 
     A_master = A_master.clone()
     C_master = layer_master(A_master)
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[k]
     C = torch.chunk(C, DEPTH, dim=0)[j]
-    logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} embed forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=device)
@@ -686,14 +732,14 @@ def check_embed():
     out.backward(grad)
     torch.cuda.synchronize()
    bwd_end = time.time()
-    logger.info('embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
+    logger.info("embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0])
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
 
     B_grad = layer_master.weight.grad
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    logger.info('Rank {} embed backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
+    logger.info("Rank {} embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad)))
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -702,7 +748,7 @@ def check_vocab_parallel_embed():
     rank = torch.distributed.get_rank()
     device = get_current_device()
     logger = get_dist_logger()
-    dtype = torch.float32
+    torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -733,16 +779,19 @@ def check_vocab_parallel_embed():
     out = layer(A)
     torch.cuda.synchronize()
     fwd_end = time.time()
-    logger.info('vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s'.format(
-        tuple(A.shape), tuple(out.shape), fwd_end - fwd_start),
-                ranks=[0])
+    logger.info(
+        "vocab parallel embed forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(A.shape), tuple(out.shape), fwd_end - fwd_start
+        ),
+        ranks=[0],
+    )
 
     A_master = A_master.clone()
     C_master = layer_master(A_master)
     C = torch.chunk(C_master, DEPTH, dim=0)[i]
     C = torch.chunk(C, DEPTH, dim=-1)[k]
     C = torch.chunk(C, DEPTH, dim=0)[j]
-    logger.info('Rank {} vocab parallel embed forward: {}'.format(rank, check_equal(out, C)))
+    logger.info("Rank {} vocab parallel embed forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
     grad_master = torch.randn(grad_shape, device=device)
@@ -755,7 +804,7 @@ def check_vocab_parallel_embed():
     out.backward(grad)
     torch.cuda.synchronize()
     bwd_end = time.time()
-    logger.info('vocab parallel embed backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
+    logger.info("vocab parallel embed backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0])
 
     grad_master = grad_master.clone()
     C_master.backward(grad_master)
@@ -764,9 +813,9 @@ def check_vocab_parallel_embed():
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
-    logger.info('Rank {} vocab parallel embed backward (weight_grad): {}'.format(rank,
-                                                                                 check_equal(B_grad,
-                                                                                             layer.weight.grad)))
+    logger.info(
+        "Rank {} vocab parallel embed backward (weight_grad): {}".format(rank, check_equal(B_grad, layer.weight.grad))
+    )
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -798,25 +847,28 @@ def check_loss():
     fwd_start = time.time()
     loss = criterion(out, target_master)
     fwd_end = time.time()
-    logger.info('cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(tuple(out.shape), tuple(loss.shape),
-                                                                                    fwd_end - fwd_start),
-                ranks=[0])
+    logger.info(
+        "cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start
+        ),
+        ranks=[0],
+    )
 
     out_master = out_master.clone()
     out_master.requires_grad = True
     loss_master = criterion_master(out_master, target_master)
-    logger.info('Rank {} cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
+    logger.info("Rank {} cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master)))
 
     bwd_start = time.time()
     loss.backward()
     bwd_end = time.time()
-    logger.info('cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
+    logger.info("cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0])
 
     loss_master.backward()
     out_grad = out_master.grad
     out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
     out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
-    logger.info('Rank {} cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
+    logger.info("Rank {} cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad)))
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -825,7 +877,7 @@ def check_vocab_parallel_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
     device = get_current_device()
-    dtype = torch.float32
+    torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -852,25 +904,28 @@ def check_vocab_parallel_loss():
     fwd_start = time.time()
     loss = criterion(out, target_master)
     fwd_end = time.time()
-    logger.info('vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
-        tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start),
-                ranks=[0])
+    logger.info(
+        "vocab parallel cross entropy loss forward: pass | {0} --> {1} | {2:.3f} s".format(
+            tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start
+        ),
+        ranks=[0],
+    )
 
     out_master = out_master.clone()
     out_master.requires_grad = True
     loss_master = criterion_master(out_master, target_master)
-    logger.info('Rank {} vocab parallel cross entropy loss forward: {}'.format(rank, check_equal(loss, loss_master)))
+    logger.info("Rank {} vocab parallel cross entropy loss forward: {}".format(rank, check_equal(loss, loss_master)))
 
     bwd_start = time.time()
     loss.backward()
     bwd_end = time.time()
-    logger.info('vocab parallel cross entropy loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start), ranks=[0])
+    logger.info("vocab parallel cross entropy loss backward: pass | {:.3f} s".format(bwd_end - bwd_start), ranks=[0])
 
     loss_master.backward()
     out_grad = out_master.grad
     out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
     out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
     out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
-    logger.info('Rank {} vocab parallel cross entropy loss backward: {}'.format(rank, check_equal(out_grad, out.grad)))
+    logger.info("Rank {} vocab parallel cross entropy loss backward: {}".format(rank, check_equal(out_grad, out.grad)))
 
     return fwd_end - fwd_start, bwd_end - bwd_start
@@ -23,7 +23,7 @@ from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gp
 CONFIG = dict(
     parallel=dict(
         pipeline=1,
-        tensor=dict(mode='3d', size=8),
+        tensor=dict(mode="3d", size=8),
     ),
     seed=42,
 )
@@ -44,7 +44,7 @@ def check_layer():
 
 def check_layer_and_operation(rank, world_size, port):
     disable_existing_loggers()
-    launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.allow_tf32 = False
     torch.backends.cudnn.deterministic = True
@@ -60,5 +60,5 @@ def test_3d():
     spawn(check_layer_and_operation, 8)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test_3d()