improved all-gather & reduce-scatter for 3D parallelism

This commit is contained in:
zbian
2023-01-03 15:26:47 +08:00
committed by アマデウス
parent c719798abe
commit e94c79f15b
4 changed files with 43 additions and 29 deletions

View File

@@ -4,12 +4,23 @@
import time
import torch
from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context
from colossalai.logging import get_dist_logger
from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
from colossalai.nn import (
Classifier3D,
CrossEntropyLoss3D,
Embedding3D,
LayerNorm3D,
Linear3D,
PatchEmbedding3D,
VanillaClassifier,
VanillaPatchEmbedding,
VocabParallelClassifier3D,
VocabParallelCrossEntropyLoss3D,
VocabParallelEmbedding3D,
)
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.utils import get_current_device, print_rank_0
@@ -40,7 +51,7 @@ def check_linear():
torch.distributed.broadcast(weight_master, src=0)
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
weight = torch.chunk(weight, DEPTH, dim=-1)[j]
weight = torch.chunk(weight, DEPTH, dim=-1)[i]
weight = torch.chunk(weight, DEPTH, dim=0)[i]
layer.weight.data.copy_(weight)
bias_master = layer_master.bias.data
torch.distributed.broadcast(bias_master, src=0)
@@ -93,7 +104,7 @@ def check_linear():
B_grad = layer_master.weight.grad.transpose(0, 1)
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
bias_grad = layer_master.bias.grad
@@ -775,7 +786,7 @@ def check_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -828,7 +839,7 @@ def check_vocab_parallel_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]