mirror of https://github.com/hpcaitech/ColossalAI.git
improved allgather & reducescatter for 3d
@@ -4,12 +4,23 @@
 import time
 
 import torch
 
 from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.core import global_context
 from colossalai.logging import get_dist_logger
-from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
-                           VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
-                           VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
+from colossalai.nn import (
+    Classifier3D,
+    CrossEntropyLoss3D,
+    Embedding3D,
+    LayerNorm3D,
+    Linear3D,
+    PatchEmbedding3D,
+    VanillaClassifier,
+    VanillaPatchEmbedding,
+    VocabParallelClassifier3D,
+    VocabParallelCrossEntropyLoss3D,
+    VocabParallelEmbedding3D,
+)
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.utils import get_current_device, print_rank_0
@@ -40,7 +51,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
-    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
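Note: with this change the i-index shard of the Linear3D weight in check_linear() is taken along dim=0 (the input dimension) rather than dim=-1, matching the layout produced by the improved all-gather/reduce-scatter path. A minimal sketch of the resulting partitioning, assuming i, j, k are this rank's coordinates in the three 3D-parallel groups (the exact group-to-index mapping here is an assumption, not taken from the source):

    import torch

    def shard_linear_weight(weight_master: torch.Tensor, depth: int, i: int, j: int, k: int) -> torch.Tensor:
        # Mirror the chunking order used by the updated test:
        # input dim across the k index, output dim across the j index,
        # then the input dim again across the i index.
        w = torch.chunk(weight_master, depth, dim=0)[k]
        w = torch.chunk(w, depth, dim=-1)[j]
        w = torch.chunk(w, depth, dim=0)[i]
        return w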
@@ -93,7 +104,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
-    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
 
     bias_grad = layer_master.bias.grad
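The backward check follows the same layout: the expected weight-gradient shard for the i index is now also cut along dim=0. For context on the commit title, the two collectives involved can be illustrated with plain torch.distributed calls; this is only a sketch of the primitives (assuming the sharded dimension is divisible by the group size), not ColossalAI's actual 3D implementation:

    import torch
    import torch.distributed as dist

    def all_gather_rows(shard: torch.Tensor, group=None) -> torch.Tensor:
        # all-gather: every rank contributes its row shard and reconstructs the full tensor
        world_size = dist.get_world_size(group)
        parts = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(parts, shard, group=group)
        return torch.cat(parts, dim=0)

    def reduce_scatter_rows(full: torch.Tensor, group=None) -> torch.Tensor:
        # reduce-scatter: sum `full` across ranks, returning only this rank's row slice
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        chunks = list(torch.chunk(full, world_size, dim=0))
        out = torch.empty_like(chunks[rank])
        dist.reduce_scatter(out, chunks, group=group)
        return out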
@@ -775,7 +786,7 @@ def check_loss():
 
     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -828,7 +839,7 @@ def check_vocab_parallel_loss():
 
     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]