improved allgather & reducescatter for 3d

zbian 2023-01-03 15:26:47 +08:00 committed by アマデウス
parent c719798abe
commit e94c79f15b
4 changed files with 43 additions and 29 deletions

View File

@@ -3,12 +3,17 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
 from torch import Tensor
+from torch.distributed import ReduceOp
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
+_all_gather_func = dist._all_gather_base \
+    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
+_reduce_scatter_func = dist._reduce_scatter_base \
+    if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
+
 
 def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
@@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
         out = tensor
         work = None
     else:
-        shape = list(tensor.shape)
-        shape[0], shape[dim] = shape[dim], shape[0]
-        shape[0] *= depth
-        out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
-        temp = list(torch.chunk(out, depth, dim=0))
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.all_gather(tensor_list=temp,
-                               tensor=tensor.transpose(0, dim).contiguous(),
-                               group=group,
-                               async_op=async_op)
-        out = torch.transpose(out, 0, dim)
+        work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
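The rewritten `all_gather` handles an arbitrary gather dimension by swapping it to the front, running the flat-tensor collective on dim 0, and swapping back, so only a non-zero `dim` pays for a contiguous copy. A shape-only sketch of that bookkeeping (single process, no communication; `gathered_shape` and `depth` are illustrative stand-ins):

import torch

def gathered_shape(tensor: torch.Tensor, dim: int, depth: int) -> torch.Size:
    # Mirror the transpose / flat-gather / transpose-back dance without any collective call.
    tensor_in = tensor if dim == 0 else tensor.transpose(0, dim)
    tensor_out = torch.empty((tensor_in.shape[0] * depth,) + tensor_in.shape[1:])
    return (tensor_out if dim == 0 else tensor_out.transpose(0, dim)).shape

x = torch.randn(4, 5, 6)
print(gathered_shape(x, dim=1, depth=2))   # torch.Size([4, 10, 6]): dim 1 grows by the group size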
@@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
         out = tensor
         work = None
     else:
-        temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
-        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
+        work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
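`reduce_scatter` gets the mirror-image treatment: the input is moved onto dim 0, reduced and scattered as one flat tensor, and the output's scatter dimension shrinks by the group size. A single-process sketch of that contract (the `simulated_reduce_scatter` helper below is illustrative, not library code):

import torch

def simulated_reduce_scatter(per_rank_inputs, dim):
    # Sum the inputs element-wise, then hand slice r (along `dim`) to rank r.
    depth = len(per_rank_inputs)
    summed = torch.stack(per_rank_inputs).sum(dim=0)
    return list(summed.chunk(depth, dim=dim))

inputs = [torch.full((4, 6), float(r)) for r in range(3)]   # three fake ranks
outs = simulated_reduce_scatter(inputs, dim=1)
print(outs[0].shape, outs[0][0, 0].item())                  # torch.Size([4, 2]) 3.0  (0 + 1 + 2)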
@@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
-    r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
+    r"""Modified from `torch.distributed.scatter_object_list
+    <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
         return

View File

@@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.output_parallel_mode = output_parallel_mode
 
         input_ = all_gather(input_, 0, input_parallel_mode)
-        weight = all_gather(weight, -1, weight_parallel_mode)
+        weight = all_gather(weight, 0, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)
 
         output = torch.matmul(input_, weight)
@@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):
         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-        weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
+        weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
 
         input_op.wait()
@@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
         ctx.weight_id = weight_id
 
         input_ = all_gather(input_, 0, input_parallel_mode)
-        weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
+        weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
         ctx.save_for_backward(input_, weight)
 
         output = torch.matmul(input_, weight)
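Switching the weight's gather and scatter dimension from -1 to 0 lines these calls up with the fast path of the rewritten collectives: when `dim == 0` the input is already contiguous and no transpose copy is made, whereas any other dimension forces a `transpose(0, dim).contiguous()` round trip. A quick single-process illustration of that cost difference:

import torch

x = torch.randn(256, 512)                                         # stands in for a weight shard
# dim == 0 path: contiguous() is a no-op, the same storage is reused.
print(x.contiguous().data_ptr() == x.data_ptr())                   # True
# dim != 0 path: the transpose is non-contiguous, so contiguous() copies.
print(x.transpose(0, 1).contiguous().data_ptr() == x.data_ptr())   # False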

View File

@@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
         self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
         self.depth = get_depth_from_env()
         self.skip_bias_add = skip_bias_add
-        self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.in_features_per_partition = divide(in_features, self.depth**2)
+        self.out_features_per_partition = divide(out_features, self.depth)
         self.bias_features_per_partition = divide(out_features, self.depth)
 
         self.weight = Parameter(
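With the flipped partition, each rank of the weight group holds a leading slice of the weight, so the forward `all_gather(weight, 0, weight_parallel_mode)` restores the full `in_features // depth` rows before the matmul. Partition arithmetic with illustrative numbers (not taken from the tests):

depth = 2                                   # cube depth; 3-D parallelism uses depth**3 ranks
in_features, out_features = 1024, 4096

local_weight = (in_features // depth**2, out_features // depth)    # (256, 2048) per rank
gathered_weight = (local_weight[0] * depth, local_weight[1])       # (512, 2048) == (in/depth, out/depth)
print(local_weight, gathered_weight)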
@@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
             local_state,
             self.weight_parallel_mode,
             dims={
-                weight_key: -1,
+                weight_key: 0,
                 bias_key: 0
             },
             partition_states={
@@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
             local_state,
             self.weight_parallel_mode,
             dims={
-                weight_key: -1,
+                weight_key: 0,
                 bias_key: 0
             },
             partition_states={

View File

@@ -4,12 +4,23 @@
 import time
 
 import torch
 
 from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.core import global_context
 from colossalai.logging import get_dist_logger
-from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
-                           VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
-                           VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
+from colossalai.nn import (
+    Classifier3D,
+    CrossEntropyLoss3D,
+    Embedding3D,
+    LayerNorm3D,
+    Linear3D,
+    PatchEmbedding3D,
+    VanillaClassifier,
+    VanillaPatchEmbedding,
+    VocabParallelClassifier3D,
+    VocabParallelCrossEntropyLoss3D,
+    VocabParallelEmbedding3D,
+)
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.utils import get_current_device, print_rank_0
@@ -40,7 +51,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
-    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
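The test now takes the last chunk of the master weight along dim 0 instead of dim -1, which reproduces the new local layout of `(in_features // DEPTH**2, out_features // DEPTH)`. A toy-sized sketch of that slicing (the sizes and cube coordinates below are made up for illustration):

import torch

DEPTH = 2
weight_master = torch.randn(8, 4)                 # (in_features, out_features)

i, j, k = 1, 0, 1                                 # one rank's coordinates in the three groups
w = torch.chunk(weight_master, DEPTH, dim=0)[k]   # (4, 4)
w = torch.chunk(w, DEPTH, dim=-1)[j]              # (4, 2)
w = torch.chunk(w, DEPTH, dim=0)[i]               # (2, 2) == (in // DEPTH**2, out // DEPTH)
print(w.shape)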
@@ -93,7 +104,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
-    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))
 
     bias_grad = layer_master.bias.grad
@@ -775,7 +786,7 @@ def check_loss():
     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -828,7 +839,7 @@ def check_vocab_parallel_loss():
     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]