From e94c79f15b2ee1428ccda2cdb2d00007697c578c Mon Sep 17 00:00:00 2001
From: zbian
Date: Tue, 3 Jan 2023 15:26:47 +0800
Subject: [PATCH] improved allgather & reducescatter for 3d

---
 colossalai/communication/collective.py        | 33 ++++++++++---------
 colossalai/nn/layer/parallel_3d/_operation.py |  6 ++--
 colossalai/nn/layer/parallel_3d/layers.py     |  8 ++---
 .../test_3d/checks_3d/check_layer_3d.py       | 25 ++++++++++----
 4 files changed, 43 insertions(+), 29 deletions(-)

diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py
index 2c9e9927c..64fb5b8b5 100644
--- a/colossalai/communication/collective.py
+++ b/colossalai/communication/collective.py
@@ -3,12 +3,17 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
 from torch import Tensor
+from torch.distributed import ReduceOp
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
+_all_gather_func = dist._all_gather_base \
+    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
+_reduce_scatter_func = dist._reduce_scatter_base \
+    if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
+
 
 def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
@@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
         out = tensor
         work = None
     else:
-        shape = list(tensor.shape)
-        shape[0], shape[dim] = shape[dim], shape[0]
-        shape[0] *= depth
-        out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
-        temp = list(torch.chunk(out, depth, dim=0))
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.all_gather(tensor_list=temp,
-                               tensor=tensor.transpose(0, dim).contiguous(),
-                               group=group,
-                               async_op=async_op)
-        out = torch.transpose(out, 0, dim)
+        work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
@@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
         out = tensor
         work = None
     else:
-        temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
-        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
+        work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
@@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
 
 
 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
-    r"""Modified from `torch.distributed.scatter_object_list ` to fix issues
+    r"""Modified from `torch.distributed.scatter_object_list
+     ` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
         return
diff --git a/colossalai/nn/layer/parallel_3d/_operation.py b/colossalai/nn/layer/parallel_3d/_operation.py
index 07869e5ad..5dc9a2428 100755
--- a/colossalai/nn/layer/parallel_3d/_operation.py
+++ b/colossalai/nn/layer/parallel_3d/_operation.py
@@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.output_parallel_mode = output_parallel_mode
 
         input_ = all_gather(input_, 0, input_parallel_mode)
-        weight = all_gather(weight, -1, weight_parallel_mode)
+        weight = all_gather(weight, 0, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)
 
         output = torch.matmul(input_, weight)
@@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):
 
         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-        weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
+        weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
         input_op.wait()
@@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
         ctx.weight_id = weight_id
 
         input_ = all_gather(input_, 0, input_parallel_mode)
-        weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
+        weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
         ctx.save_for_backward(input_, weight)
 
         output = torch.matmul(input_, weight)
diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py
index 0a1db6800..99b0c3f8b 100644
--- a/colossalai/nn/layer/parallel_3d/layers.py
+++ b/colossalai/nn/layer/parallel_3d/layers.py
@@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
         self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
         self.depth = get_depth_from_env()
         self.skip_bias_add = skip_bias_add
-        self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.in_features_per_partition = divide(in_features, self.depth**2)
+        self.out_features_per_partition = divide(out_features, self.depth)
         self.bias_features_per_partition = divide(out_features, self.depth)
 
         self.weight = Parameter(
@@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
                 local_state,
                 self.weight_parallel_mode,
                 dims={
-                    weight_key: -1,
+                    weight_key: 0,
                     bias_key: 0
                 },
                 partition_states={
@@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
                 local_state,
                 self.weight_parallel_mode,
                 dims={
-                    weight_key: -1,
+                    weight_key: 0,
                     bias_key: 0
                 },
                 partition_states={
diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
index 9e199e22e..e946a1f59 100644
--- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
+++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py
@@ -4,12 +4,23 @@
 import time
 
 import torch
+
 from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.core import global_context
 from colossalai.logging import get_dist_logger
-from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
-                           VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
-                           VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
+from colossalai.nn import (
+    Classifier3D,
+    CrossEntropyLoss3D,
+    Embedding3D,
+    LayerNorm3D,
+    Linear3D,
+    PatchEmbedding3D,
+    VanillaClassifier,
+    VanillaPatchEmbedding,
+    VocabParallelClassifier3D,
+    VocabParallelCrossEntropyLoss3D,
+    VocabParallelEmbedding3D,
+)
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.utils import get_current_device, print_rank_0
 
@@ -40,7 +51,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
-    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
@@ -93,7 +104,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
-    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
 
     bias_grad = layer_master.bias.grad
@@ -775,7 +786,7 @@ def check_loss():
 
     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -828,7 +839,7 @@ def check_vocab_parallel_loss():
 
     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
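
Note (not part of the patch): below is a minimal standalone sketch of the gather pattern the new all_gather adopts, i.e. transpose the gather dimension to dim 0, run the flat tensor-based collective, then transpose back, with the same version fallback between the public all_gather_into_tensor and the older private _all_gather_base. The helper name demo_all_gather and the single-process gloo setup are illustrative assumptions for this sketch only; ColossalAI's real all_gather additionally handles parallel modes, CPU groups, and async_op. The reduce_scatter change follows the same pattern with the output shard size divided by the group depth instead of multiplied.

# Illustrative sketch, not ColossalAI code.
import torch
import torch.distributed as dist

# Same version fallback as the patch (hasattr is equivalent to the
# `"all_gather_into_tensor" not in dir(dist)` check): newer PyTorch exposes
# all_gather_into_tensor, older releases only have the private _all_gather_base.
_all_gather_func = (dist.all_gather_into_tensor
                    if hasattr(dist, "all_gather_into_tensor") else dist._all_gather_base)


def demo_all_gather(tensor: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # Move the gather dimension to the front so the collective can write one
    # contiguous (world_size * shard_size, ...) buffer instead of a tensor list.
    tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
    out_shape = (tensor_in.shape[0] * world_size,) + tensor_in.shape[1:]
    tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
    _all_gather_func(tensor_out, tensor_in, group=group)
    # Transpose back; like the patched all_gather, the result is a view and
    # may be non-contiguous when dim != 0.
    return tensor_out if dim == 0 else tensor_out.transpose(0, dim)


if __name__ == "__main__":
    # Single-process group just to show the call runs; with world_size 1 the
    # gather is a plain copy (ColossalAI short-circuits this case entirely).
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    x = torch.arange(6.0).reshape(2, 3)
    print(demo_all_gather(x, dim=-1).shape)    # torch.Size([2, 3]) on one rank
    dist.destroy_process_group()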