Migrated project
15
tests/test_layers/test_3d/common.py
Normal file
@@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

DEPTH = 2
BATCH_SIZE = 512
SEQ_LENGTH = 128
HIDDEN_SIZE = 512
NUM_CLASSES = 10
NUM_BLOCKS = 6
IMG_SIZE = 32

def check_equal(A, B):
    return torch.allclose(A, B, rtol=1e-5, atol=1e-2)
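For orientation, a short sketch of what these constants imply for the tests that follow (the name local_shape is introduced here for illustration only): with DEPTH = 2 the 3D mesh has DEPTH ** 3 = 8 ranks, and the checks shard a (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) activation twice along the batch dimension and once along the hidden dimension.

# local activation shard per rank, following the chunking order used in test_layer.py
local_shape = (BATCH_SIZE // DEPTH ** 2,  # 512 // 4 = 128
               SEQ_LENGTH,                # 128, not partitioned
               HIDDEN_SIZE // DEPTH)      # 512 // 2 = 256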
22
tests/test_layers/test_3d/test.sh
Normal file
@@ -0,0 +1,22 @@
#!/bin/bash

python -m torch.distributed.launch --nproc_per_node 8 test_3d.py --host $HOST --port 29516 --world_size 8

# expected test output
# distributed environment initialized
# AB forward: pass
# AB backward: pass
# ABT forward: pass
# ABT backward: pass
# ATB forward: pass
# ATB backward: pass
# linear forward: pass
# linear backward: pass
# layer norm forward: pass
# layer norm backward: pass
# self attention forward: pass
# self attention backward: pass
# mlp forward: pass
# mlp backward: pass
# transformerlayer forward: pass
# transformerlayer backward: pass
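For reference, a minimal single-node invocation of the script above (the address 127.0.0.1 is an assumption; use whatever address of the launcher node the workers can reach):

# assumes 8 local GPUs and the file layout of this commit
cd tests/test_layers/test_3d && HOST=127.0.0.1 bash test.sh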
58
tests/test_layers/test_3d/test_3d.py
Normal file
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from colossalai.initialize import init_dist

from test_layer import *
from test_operation import *

CONFIG = dict(parallel=dict(pipeline=1, tensor=dict(mode='3d', size=8)),
              seed=0)


def check_operations():
    check_AB()
    check_ABT()
    check_ATB()
    check_add()
    check_mul()
    check_sum()
    # check_pooler()


def check_layer():
    logger = get_global_dist_logger()
    linear_fwd_time, linear_bwd_time = check_linear()
    norm_fwd_time, norm_bwd_time = check_layernorm()
    attn_fwd_time, attn_bwd_time = check_attention()
    mlp_fwd_time, mlp_bwd_time = check_mlp()
    head_fwd_time, head_bwd_time = check_head()
    embed_fwd_time, embed_bwd_time = check_embed()
    loss_fwd_time, loss_bwd_time = check_loss()
    block_fwd_time = norm_fwd_time + attn_fwd_time + norm_fwd_time + mlp_fwd_time
    block_bwd_time = norm_bwd_time + attn_bwd_time + norm_bwd_time + mlp_bwd_time
    fwd_time = embed_fwd_time + NUM_BLOCKS * block_fwd_time + norm_fwd_time + head_fwd_time + loss_fwd_time
    bwd_time = embed_bwd_time + NUM_BLOCKS * block_bwd_time + norm_bwd_time + head_bwd_time + loss_bwd_time
    logger.info('ViT forward time: {:.3f} s | backward time: {:.3f} s'.format(
        fwd_time, bwd_time),
                ranks=[0])


def _test_main():
    # init dist
    init_dist(CONFIG)
    logger = get_global_dist_logger()
    logger.info('Distributed environment is initialized.', ranks=[0])

    global_context.set_seed()
    torch.backends.cudnn.benchmark = True

    # check operation
    check_operations()

    # check layers
    check_layer()


if __name__ == '__main__':
    _test_main()
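Why the block totals in check_layer() count the layernorm time twice (this reading of the sums is an inference about the ViT variant being modeled, not something stated in the file):

# pre-norm transformer block, as approximated by block_fwd_time / block_bwd_time:
#   x = x + attention(norm(x))
#   x = x + mlp(norm(x))
# so each block contributes two layernorms, one attention and one MLP, and the final
# layernorm, head and loss are added once on top of the NUM_BLOCKS blocks.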
19
tests/test_layers/test_3d/test_conn.py
Normal file
@@ -0,0 +1,19 @@
import torch
import torch.distributed as dist

from colossalai.initialize import parse_args
from colossalai.utils import get_current_device

ARGS = parse_args()
size = ARGS.world_size
rank = ARGS.local_rank

init_method = f'tcp://{ARGS.host}:{ARGS.port}'
dist.init_process_group(backend='nccl', rank=rank, world_size=size, init_method=init_method)
print('Rank {} / {}'.format(dist.get_rank(), dist.get_world_size()))

SIZE = 8
tensor = torch.randn(SIZE)
tensor = tensor.to(get_current_device())
dist.all_reduce(tensor)
print('Rank {0}: {1}'.format(rank, tensor.detach().cpu().numpy().tolist()))
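To assert, rather than eyeball, that the all-reduce above worked, a small follow-up check could be appended (a sketch using only primitives already imported; gathered is a name introduced here):

# after all_reduce (default op: SUM) every rank holds the same summed vector,
# so all gathered copies must equal the local result
gathered = [torch.empty_like(tensor) for _ in range(size)]
dist.all_gather(gathered, tensor)
assert all(torch.equal(t, tensor) for t in gathered)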
640
tests/test_layers/test_3d/test_layer.py
Normal file
@@ -0,0 +1,640 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context
|
||||
from colossalai.logging import get_global_dist_logger
|
||||
from colossalai.registry import LAYERS, LOSSES
|
||||
from colossalai.utils import get_current_device, print_rank_0
|
||||
|
||||
from common import *
|
||||
|
||||
|
||||
def check_linear():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
OUTPUT_SIZE = 2 * HIDDEN_SIZE
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
layer = LAYERS.get_module('Linear3D')(INPUT_SIZE,
|
||||
OUTPUT_SIZE,
|
||||
ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT,
|
||||
dtype=dtype,
|
||||
bias=True)
|
||||
torch.nn.init.zeros_(layer.bias)
|
||||
torch.nn.init.ones_(layer.weight)
|
||||
layer = layer.to(device)
|
||||
layer_master = torch.nn.Linear(INPUT_SIZE, OUTPUT_SIZE)
|
||||
torch.nn.init.zeros_(layer_master.bias)
|
||||
torch.nn.init.ones_(layer_master.weight)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'linear forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} linear forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape,
|
||||
dtype=dtype,
|
||||
device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
bwd_end = time.time()
|
||||
print_rank_0('linear backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} linear backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = layer_master.weight.grad.transpose(0, 1)
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
|
||||
logger.info('Rank {} linear backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, layer.weight.grad)))
|
||||
|
||||
if j == k:
|
||||
bias_grad = layer_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
logger.info('Rank {} linear backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, layer.bias.grad)))
|
||||
else:
|
||||
logger.info('Rank {} linear backward (bias_grad): {}'.format(
|
||||
rank,
|
||||
# np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
|
||||
layer.bias.grad is None))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
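# Pattern used by check_linear() above and reused by the checks below: a full-size
# master tensor is broadcast from rank 0 so every rank sees identical data, each rank
# slices out its own (i, j, k) shard, the 3D-parallel layer runs on the shard while a
# plain torch.nn reference runs on the full tensor on every rank, and the matching
# shard of the reference output and gradients is compared with check_equal.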
|
||||
|
||||
|
||||
def check_layernorm():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
norm = LAYERS.get_module('LayerNorm3D')(INPUT_SIZE,
|
||||
ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT,
|
||||
eps=1e-6,
|
||||
dtype=dtype)
|
||||
norm = norm.to(device)
|
||||
norm_master = torch.nn.LayerNorm(INPUT_SIZE, eps=1e-6)
|
||||
norm_master = norm_master.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = norm(A)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'layer norm forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = norm_master(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} layernorm forward: {}'.format(rank,
|
||||
check_equal(out, C)))
|
||||
# time.sleep(rank)
|
||||
# logger.info('Rank {0} master:\n{1}\nRank {0} out:\n{2}\nRank {0} true:\n{3}\n'.
|
||||
# format(rank,
|
||||
# C_master.detach().cpu().numpy().tolist(),
|
||||
# out.detach().cpu().numpy().tolist(),
|
||||
# C.detach().cpu().numpy().tolist()))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
bwd_end = time.time()
|
||||
print_rank_0(
|
||||
'layer norm backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} layernorm backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
if j == k:
|
||||
bias_grad = norm_master.weight.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
logger.info('Rank {} layernorm backward (weight_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, norm.weight.grad)))
|
||||
else:
|
||||
logger.info('Rank {} layernorm backward (weight_grad): {}'.format(
|
||||
rank,
|
||||
# np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
|
||||
norm.weight.grad is None))
|
||||
|
||||
if j == k:
|
||||
bias_grad = norm_master.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
logger.info('Rank {} layernorm backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, norm.bias.grad)))
|
||||
else:
|
||||
logger.info('Rank {} layernorm backward (bias_grad): {}'.format(
|
||||
rank,
|
||||
# np.count_nonzero(layer.bias.grad.detach().cpu().numpy()) == 0))
|
||||
norm.bias.grad is None))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
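# The j == k branches above encode what the test expects from LayerNorm3D: the
# weight and bias gradients are only populated on ranks whose input and output
# coordinates coincide, and on all other ranks the parameter .grad should remain
# None, which is what the else branches log.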
|
||||
|
||||
|
||||
def check_attention():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
NUM_ATTENTION_HEADS = 2
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
layer = LAYERS.get_module('ViTSelfAttention3D')(HIDDEN_SIZE,
|
||||
NUM_ATTENTION_HEADS,
|
||||
0.,
|
||||
0.1,
|
||||
dtype=dtype,
|
||||
bias=True)
|
||||
layer = layer.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH,
|
||||
SEQ_LENGTH // DEPTH, SEQ_LENGTH // DEPTH)
|
||||
attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'self attention forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
|
||||
grad_shape = out.shape
|
||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
bwd_end = time.time()
|
||||
print_rank_0(
|
||||
'self attention backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
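# Unlike check_linear and check_layernorm, this check only exercises shapes and
# timing: there is no master reference comparison, and the attention_mask built
# above is not passed into the layer call.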
|
||||
|
||||
|
||||
def check_mlp():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
layer = LAYERS.get_module('ViTMLP3D')(HIDDEN_SIZE,
|
||||
1,
|
||||
0.1,
|
||||
'gelu',
|
||||
dtype=dtype,
|
||||
bias=True)
layer = layer.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'mlp forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
|
||||
grad_shape = out.shape
|
||||
grad = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
bwd_end = time.time()
|
||||
print_rank_0('mlp backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
class Testvithead(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = x[:, 0]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def check_head():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
INPUT_SIZE = HIDDEN_SIZE
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
head = LAYERS.get_module('ViTHead3D')(INPUT_SIZE,
|
||||
NUM_CLASSES,
|
||||
dtype=dtype,
|
||||
bias=True)
|
||||
torch.nn.init.zeros_(head.linear.bias)
|
||||
torch.nn.init.ones_(head.linear.weight)
|
||||
head = head.to(device)
|
||||
|
||||
layer = Testvithead(INPUT_SIZE, NUM_CLASSES, bias=True)
|
||||
torch.nn.init.zeros_(layer.linear.bias)
|
||||
torch.nn.init.ones_(layer.linear.weight)
|
||||
layer = layer.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = head(A)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'head forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer(A_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} head forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape,
|
||||
dtype=dtype,
|
||||
device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
bwd_end = time.time()
|
||||
print_rank_0('head backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
# if j == 0:
|
||||
logger.info('Rank {} head backward (input_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
# else:
|
||||
# logger.info('Rank {} head backward (input_grad): {}'.format(
|
||||
# # rank, check_equal(A_grad, A.grad)))
|
||||
# rank,
|
||||
# A.grad is None))
|
||||
|
||||
B_grad = layer.linear.weight.grad.transpose(0, 1)
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
pad_shape = (B_grad.shape[0], math.ceil(B_grad.shape[-1] / DEPTH) * DEPTH -
|
||||
B_grad.shape[-1])
|
||||
B_grad = torch.cat(
|
||||
[B_grad, torch.zeros(pad_shape, dtype=dtype, device=device)], dim=-1)
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
|
||||
logger.info('Rank {} head backward (weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, head.linear.weight.grad)))
|
||||
|
||||
if j == k:
|
||||
bias_grad = layer.linear.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
pad_shape = (math.ceil(bias_grad.shape[0] / DEPTH) * DEPTH -
|
||||
bias_grad.shape[0], )
|
||||
bias_grad = torch.cat(
|
||||
[bias_grad,
|
||||
torch.zeros(pad_shape, dtype=dtype, device=device)])
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
logger.info('Rank {} head backward (bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, head.linear.bias.grad)))
|
||||
else:
|
||||
logger.info('Rank {} head backward (bias_grad): {}'.format(
|
||||
rank,
|
||||
# np.count_nonzero(
|
||||
# head.linear.bias.grad.detach().cpu().numpy()) == 0))
|
||||
head.linear.bias.grad is None))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
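# About the padding above: NUM_CLASSES = 10 does not split evenly across the
# DEPTH x DEPTH partitions of the class dimension, so the reference weight and bias
# gradients are zero-padded to the next multiple of DEPTH before the final chunk,
# matching (by assumption) the padded classifier dimension kept by ViTHead3D.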
|
||||
|
||||
|
||||
class Testvitembed(torch.nn.Module):
|
||||
def __init__(self, img_size: int, patch_size: int, in_chans: int,
|
||||
embed_size: int, drop_prob: float) -> None:
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv2d(in_chans,
|
||||
embed_size,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size)
|
||||
num_patches = (img_size // patch_size)**2
|
||||
self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_size))
|
||||
self.pos_embed = torch.nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_size))
|
||||
self.pos_drop = torch.nn.Dropout(drop_prob)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
return x
|
||||
|
||||
|
||||
def check_embed():
|
||||
rank = torch.distributed.get_rank()
|
||||
device = get_current_device()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float32
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
layer = LAYERS.get_module('ViTPatchEmbedding3D')(IMG_SIZE, 4, 3,
|
||||
HIDDEN_SIZE, 0.)
|
||||
torch.nn.init.zeros_(layer.proj.bias)
|
||||
torch.nn.init.ones_(layer.proj.weight)
|
||||
torch.nn.init.ones_(layer.cls_token)
|
||||
torch.nn.init.ones_(layer.pos_embed)
|
||||
layer = layer.to(device)
|
||||
|
||||
layer_master = Testvitembed(IMG_SIZE, 4, 3, HIDDEN_SIZE, 0.)
|
||||
torch.nn.init.zeros_(layer_master.proj.bias)
|
||||
torch.nn.init.ones_(layer_master.proj.weight)
|
||||
torch.nn.init.ones_(layer_master.cls_token)
|
||||
torch.nn.init.ones_(layer_master.pos_embed)
|
||||
layer_master = layer_master.to(device)
|
||||
|
||||
A_shape = (BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = A_master.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
out = layer(A)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'embedding forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(A.shape), tuple(out.shape), fwd_end - fwd_start), logger)
|
||||
# out_cls = out[:, 0]
|
||||
# out_tensor = out[:, 1:]
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = layer_master(A_master)
|
||||
# if j == 0:
|
||||
# C_cls = C_master[:, 0]
|
||||
# C_cls = torch.chunk(C_cls, DEPTH, dim=0)[i]
|
||||
# C_cls = torch.chunk(C_cls, DEPTH, dim=-1)[k]
|
||||
# logger.info('Rank {} embed forward (cls): {}'.format(
|
||||
# rank, check_equal(out_cls, C_cls)))
|
||||
# C = C_master[:, 1:]
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} embed forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape,
|
||||
dtype=dtype,
|
||||
device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
# cls_grad = grad_master[:, 0]
|
||||
# cls_grad = torch.chunk(cls_grad, DEPTH, dim=0)[i]
|
||||
# cls_grad = torch.chunk(cls_grad, DEPTH, dim=-1)[k]
|
||||
# grad = grad_master[:, 1:]
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
# grad = torch.cat((torch.unsqueeze(cls_grad, 1), grad), dim=1)
|
||||
bwd_start = time.time()
|
||||
out.backward(grad)
|
||||
bwd_end = time.time()
|
||||
print_rank_0(
|
||||
'embedding backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
# A_grad = A_master.grad
|
||||
# logger.info('Rank {} embed backward (input_grad): {}'.format(
|
||||
# rank, check_equal(A_grad, A.grad)))
|
||||
# time.sleep(0.1 * rank)
|
||||
# logger.info(
|
||||
# 'Rank {0} master:\n{1}\nRank {0} out:\n{2}\nRank {0} true:\n{3}\n'.
|
||||
# format(rank,
|
||||
# A_master.grad.detach().cpu().numpy().tolist(),
|
||||
# A.grad.detach().cpu().numpy().tolist(),
|
||||
# A_grad.detach().cpu().numpy().tolist()), ranks=[0])
|
||||
|
||||
cls_grad_master = layer_master.cls_token.grad
|
||||
cls_grad = torch.chunk(cls_grad_master, DEPTH, dim=-1)[k]
|
||||
# if j == 0:
|
||||
logger.info('Rank {} embed backward (cls_grad): {}'.format(
|
||||
rank, check_equal(cls_grad, layer.cls_token.grad)))
|
||||
# else:.
|
||||
# logger.info('Rank {} embed backward (cls_grad): {}'.format(
|
||||
# rank,
|
||||
# layer.cls_token.grad is None or np.count_nonzero(
|
||||
# layer.cls_token.grad.detach().cpu().numpy()) == 0))
|
||||
|
||||
pos_grad_master = layer_master.pos_embed.grad
|
||||
pos_grad = torch.chunk(pos_grad_master, DEPTH, dim=-1)[k]
|
||||
logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
|
||||
rank, check_equal(pos_grad, layer.pos_embed.grad)))
|
||||
# if i == 0:
|
||||
# pos_cls_grad = pos_grad[:, 0]
|
||||
# pos_tensor_grad = pos_grad[:, 1:]
|
||||
# pos_tensor_grad = torch.chunk(pos_tensor_grad, DEPTH, dim=1)[j]
|
||||
# if j == 0:
|
||||
# logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
|
||||
# rank,
|
||||
# check_equal(
|
||||
# torch.cat(
|
||||
# (torch.unsqueeze(pos_cls_grad, 1), pos_tensor_grad),
|
||||
# dim=1), layer.pos_embed.grad)))
|
||||
# else:
|
||||
# logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
|
||||
# rank, check_equal(pos_tensor_grad, layer.pos_embed.grad[:,
|
||||
# 1:])))
|
||||
# else:
|
||||
# logger.info('Rank {} embed backward (pos_embed_grad): {}'.format(
|
||||
# rank, layer.pos_embed.grad is None))
|
||||
|
||||
B_grad = layer_master.proj.weight.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} embed backward (proj_weight_grad): {}'.format(
|
||||
rank, check_equal(B_grad, layer.proj.weight.grad)))
|
||||
|
||||
bias_grad = layer_master.proj.bias.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[k]
|
||||
logger.info('Rank {} embed backward (proj_bias_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, layer.proj.bias.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
|
||||
|
||||
|
||||
def check_loss():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float32
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
criterion = LOSSES.get_module('CrossEntropyLoss3D')(
|
||||
ParallelMode.PARALLEL_3D_INPUT, ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
criterion_master = torch.nn.CrossEntropyLoss()
|
||||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ),
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
out = torch.chunk(out, DEPTH, dim=-1)[k]
|
||||
out = torch.chunk(out, DEPTH, dim=0)[j]
|
||||
out = out.clone()
|
||||
out.requires_grad = True
|
||||
|
||||
fwd_start = time.time()
|
||||
loss = criterion(out, target_master)
|
||||
fwd_end = time.time()
|
||||
print_rank_0(
|
||||
'loss forward: pass | {0} --> {1} | {2:.3f} s'.format(
|
||||
tuple(out.shape), tuple(loss.shape), fwd_end - fwd_start), logger)
|
||||
|
||||
out_master = out_master.clone()
|
||||
out_master.requires_grad = True
|
||||
loss_master = criterion_master(out_master, target_master)
|
||||
logger.info('Rank {} CrossEntropyLoss forward: {}'.format(
|
||||
rank, check_equal(loss, loss_master)))
|
||||
|
||||
bwd_start = time.time()
|
||||
loss.backward()
|
||||
bwd_end = time.time()
|
||||
print_rank_0('loss backward: pass | {:.3f} s'.format(bwd_end - bwd_start),
|
||||
logger)
|
||||
|
||||
loss_master.backward()
|
||||
out_grad = out_master.grad
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[i]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[k]
|
||||
out_grad = torch.chunk(out_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} CrossEntropyLoss backward: {}'.format(
|
||||
rank, check_equal(out_grad, out.grad)))
|
||||
|
||||
return fwd_end - fwd_start, bwd_end - bwd_start
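# The loss check compares the scalar from CrossEntropyLoss3D directly against the
# full-batch torch.nn.CrossEntropyLoss value, i.e. the parallel criterion is expected
# to return the already-reduced global mean on every rank, and the input gradient is
# then checked shard by shard like the other layers.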
|
465
tests/test_layers/test_3d/test_operation.py
Normal file
@@ -0,0 +1,465 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context
|
||||
from colossalai.logging import get_global_dist_logger
|
||||
from colossalai.nn.layer.parallel_3d._operation import *
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from common import *
|
||||
|
||||
|
||||
def check_AB():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float
|
||||
j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
|
||||
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
B = torch.chunk(B_master, DEPTH, dim=0)[k]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[j]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[i]
|
||||
B = B.clone()
|
||||
B.requires_grad = True
|
||||
|
||||
out = Matmul_AB_3D.apply(A, B, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT,
|
||||
ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
C_master = torch.matmul(A_master, B_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
# check forward correctness
|
||||
logger.info('Rank {} AB forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape,
|
||||
dtype=dtype,
|
||||
device=get_current_device())
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[k]
|
||||
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
# check backward correctness
|
||||
logger.info('Rank {} AB backward (A_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
B_grad = B_master.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
|
||||
# check backward correctness
|
||||
logger.info('Rank {} AB backward (B_grad): {}'.format(
|
||||
rank, check_equal(B_grad, B.grad)))
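# Worked shapes for the AB check above (arithmetic from common.py, DEPTH = 2):
#   local A shard: (BATCH_SIZE // 4, SEQ_LENGTH, HIDDEN_SIZE // 2)     = (128, 128, 256)
#   local B shard: (HIDDEN_SIZE // 2, 4 * HIDDEN_SIZE // 4)            = (256, 512)
#   local C shard: (BATCH_SIZE // 4, SEQ_LENGTH, 4 * HIDDEN_SIZE // 2) = (128, 128, 1024)
# i.e. each of the 8 ranks multiplies its A and B shards and, after the collective
# communication inside Matmul_AB_3D, ends up holding one shard of the full
# (512, 128, 2048) product.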
|
||||
|
||||
|
||||
def check_ABT():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
device = get_current_device()
|
||||
|
||||
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
|
||||
C_master = torch.randn(C_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(C_master, src=0)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
C = C.clone()
|
||||
C.requires_grad = True
|
||||
|
||||
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
|
||||
B_master = torch.randn(B_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
B = torch.chunk(B_master, DEPTH, dim=0)[k]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[j]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[i]
|
||||
B = B.clone()
|
||||
B.requires_grad = True
|
||||
|
||||
out = Matmul_ABT_3D.apply(C, B, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT,
|
||||
ParallelMode.PARALLEL_3D_INPUT)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
C_master = C_master.clone()
|
||||
C_master.requires_grad = True
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
A_master = torch.matmul(C_master, B_master.transpose(0, 1))
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} ABT forward: {}'.format(rank, check_equal(out, A)))
|
||||
|
||||
grad_shape = A_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
|
||||
# backward
|
||||
out.backward(grad)
|
||||
|
||||
A_master.backward(grad_master)
|
||||
C_grad = C_master.grad
|
||||
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
|
||||
C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
|
||||
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} ABT backward (C_grad): {}'.format(
|
||||
rank, check_equal(C_grad, C.grad)))
|
||||
|
||||
B_grad = B_master.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
|
||||
logger.info('Rank {} ABT backward (B_grad): {}'.format(
|
||||
rank, check_equal(B_grad, B.grad)))
|
||||
|
||||
|
||||
def check_ATB():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
device = get_current_device()
|
||||
dtype = torch.float
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
|
||||
C_master = torch.randn(C_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(C_master, src=0)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[j]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[k]
|
||||
C = C.clone()
|
||||
C.requires_grad = True
|
||||
|
||||
out = Matmul_ATB_3D.apply(A, C, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_OUTPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
|
||||
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = C_master.clone()
|
||||
C_master.requires_grad = True
|
||||
B_master = torch.matmul(
|
||||
A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
|
||||
C_master.view(-1, C_master.shape[-1]))
|
||||
B = torch.chunk(B_master, DEPTH, dim=0)[k]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[j]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[i]
|
||||
logger.info('Rank {} ATB forward: {}'.format(rank, check_equal(out, B)))
|
||||
|
||||
grad_shape = B_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[i]
|
||||
|
||||
out.backward(grad)
|
||||
|
||||
B_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} ATB backward (A_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
C_grad = C_master.grad
|
||||
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
|
||||
C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
|
||||
C_grad = torch.chunk(C_grad, DEPTH, dim=0)[k]
|
||||
logger.info('Rank {} ATB backward (C_grad): {}'.format(
|
||||
rank, check_equal(C_grad, C.grad)))
|
||||
|
||||
|
||||
def check_add():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
device = get_current_device()
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
bias_shape = (HIDDEN_SIZE, )
|
||||
bias_master = torch.randn(bias_shape,
|
||||
dtype=dtype,
|
||||
device=get_current_device())
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[j]
|
||||
bias = torch.chunk(bias, DEPTH)[i]
|
||||
bias = bias.clone()
|
||||
bias.requires_grad = True
|
||||
|
||||
out = Add_3D.apply(A, bias, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT,
|
||||
ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
bias_master = bias_master.clone()
|
||||
bias_master.requires_grad = True
|
||||
C_master = A_master + bias_master
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
|
||||
logger.info('Rank {} Add forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} Add backward (A_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
if j == k:
|
||||
bias_grad = bias_master.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
logger.info('Rank {} Add backward (b_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, bias.grad)))
|
||||
else:
|
||||
logger.info('Rank {} Add backward (b_grad): {}'.format(
|
||||
rank,
|
||||
# np.count_nonzero(bias.grad.detach().cpu().numpy()) == 0))
|
||||
bias.grad is None))
|
||||
|
||||
|
||||
def check_mul():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
device = get_current_device()
|
||||
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
bias_shape = (HIDDEN_SIZE, )
|
||||
bias_master = torch.randn(bias_shape,
|
||||
dtype=dtype,
|
||||
device=get_current_device())
|
||||
torch.distributed.broadcast(bias_master, src=0)
|
||||
bias = torch.chunk(bias_master, DEPTH)[j]
|
||||
bias = torch.chunk(bias, DEPTH)[i]
|
||||
bias = bias.clone()
|
||||
bias.requires_grad = True
|
||||
|
||||
out = Mul_3D.apply(A, bias, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT,
|
||||
ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
bias_master = bias_master.clone()
|
||||
bias_master.requires_grad = True
|
||||
C_master = torch.mul(A_master, bias_master)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=-1)[k]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
|
||||
logger.info('Rank {} Mul forward: {}'.format(rank, check_equal(out, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=-1)[k]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
|
||||
out.backward(grad)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} Mul backward (A_grad): {}'.format(
|
||||
rank, check_equal(A_grad, A.grad)))
|
||||
|
||||
if j == k:
|
||||
bias_grad = bias_master.grad
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[j]
|
||||
bias_grad = torch.chunk(bias_grad, DEPTH)[i]
|
||||
logger.info('Rank {} Mul backward (b_grad): {}'.format(
|
||||
rank, check_equal(bias_grad, bias.grad)))
|
||||
else:
|
||||
logger.info('Rank {} Mul backward (b_grad): {}'.format(
|
||||
rank,
|
||||
# np.count_nonzero(bias.grad.detach().cpu().numpy()) == 0))
|
||||
bias.grad is None))
|
||||
|
||||
|
||||
def check_sum():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
device = get_current_device()
|
||||
|
||||
# tensor
|
||||
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
|
||||
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(A_master, src=0)
|
||||
A = torch.chunk(A_master, DEPTH, dim=0)[i]
|
||||
A = torch.chunk(A, DEPTH, dim=-1)[k]
|
||||
A = torch.chunk(A, DEPTH, dim=0)[j]
|
||||
A = A.clone()
|
||||
A.requires_grad = True
|
||||
|
||||
out_tensor = Sum_3D.apply(A, -1, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
|
||||
A_master = A_master.clone()
|
||||
A_master.requires_grad = True
|
||||
C_master = torch.sum(A_master, dim=-1)
|
||||
C = torch.chunk(C_master, DEPTH, dim=0)[i]
|
||||
C = torch.chunk(C, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} Sum forward: {}'.format(rank,
|
||||
check_equal(out_tensor, C)))
|
||||
|
||||
grad_shape = C_master.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
|
||||
grad = torch.chunk(grad, DEPTH, dim=0)[j]
|
||||
|
||||
out_tensor.backward(grad / DEPTH)
|
||||
|
||||
C_master.backward(grad_master)
|
||||
A_grad = A_master.grad
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
|
||||
A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
|
||||
logger.info('Rank {} Sum backward: {}'.format(rank,
|
||||
check_equal(A_grad, A.grad)))
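# The backward seed above is grad / DEPTH rather than grad; presumably Sum_3D's
# backward accumulates the incoming gradient across the DEPTH ranks of the reduced
# dimension, and the division compensates so that A.grad matches the serial
# reference. This is an inference from the test, not from Sum_3D itself.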
|
||||
|
||||
|
||||
def check_reduce():
|
||||
rank = torch.distributed.get_rank()
|
||||
logger = get_global_dist_logger()
|
||||
dtype = torch.float
|
||||
|
||||
j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
|
||||
i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
device = get_current_device()
|
||||
|
||||
# scaler
|
||||
B_shape = (DEPTH * DEPTH, DEPTH)
|
||||
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
|
||||
torch.distributed.broadcast(B_master, src=0)
|
||||
B = torch.chunk(B_master, DEPTH, dim=0)[i]
|
||||
B = torch.chunk(B, DEPTH, dim=-1)[k]
|
||||
B = torch.chunk(B, DEPTH, dim=0)[j]
|
||||
B = torch.squeeze(B)
|
||||
B = B.clone()
|
||||
B.requires_grad = True
|
||||
|
||||
out_scaler = Reduce_3D.apply(B, 0, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT)
|
||||
out_scaler = Reduce_3D.apply(out_scaler, 0, DEPTH,
|
||||
ParallelMode.PARALLEL_3D_INPUT)
|
||||
out_scaler = Reduce_3D.apply(out_scaler, 0, DEPTH,
|
||||
ParallelMode.PARALLEL_3D_WEIGHT)
|
||||
|
||||
B_master = B_master.clone()
|
||||
B_master.requires_grad = True
|
||||
D = torch.sum(B_master)
|
||||
logger.info('Rank {} Reduce forward: {}'.format(rank,
|
||||
check_equal(out_scaler,
|
||||
D)))
|
||||
|
||||
grad_shape = D.shape
|
||||
grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
|
||||
torch.distributed.broadcast(grad_master, src=0)
|
||||
|
||||
out_scaler.backward(grad_master)
|
||||
|
||||
D.backward(grad_master)
|
||||
B_grad = B_master.grad
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
|
||||
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
|
||||
B_grad = torch.squeeze(B_grad)
|
||||
logger.info('Rank {} Reduce backward: {}'.format(
|
||||
rank, check_equal(B_grad, B.grad)))
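# Note: check_reduce is defined here but is not called from check_operations() in
# test_3d.py, so it does not run as part of the launched test.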
|