improved allgather & reducescatter for 3d

This commit is contained in:
    parent c719798abe
    commit e94c79f15b
@@ -3,12 +3,17 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
 from torch import Tensor
+from torch.distributed import ReduceOp

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc

+_all_gather_func = dist._all_gather_base \
+    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
+_reduce_scatter_func = dist._reduce_scatter_base \
+    if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
+

 def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
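For context, a minimal sketch of the pattern this hunk introduces: newer PyTorch releases expose the flat-tensor collectives as `all_gather_into_tensor` / `reduce_scatter_tensor`, while older releases only ship the private `_all_gather_base` / `_reduce_scatter_base` with the same calling convention, so the module resolves whichever exists at import time. The `flat_all_gather` helper below is hypothetical and only illustrates the call shape.

```python
import torch
import torch.distributed as dist

# Hedged sketch: pick the flat all-gather available in the installed torch,
# then gather each rank's shard into one preallocated buffer along dim 0.
_all_gather_func = dist._all_gather_base \
    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor

def flat_all_gather(shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
    # Assumes an initialized process group; every rank passes a shard of equal shape.
    shard = shard.contiguous()
    out = torch.empty((shard.shape[0] * world_size,) + shard.shape[1:],
                      dtype=shard.dtype, device=shard.device)
    _all_gather_func(out, shard, group=group)  # all shards concatenated along dim 0
    return out
```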
@@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
         out = tensor
         work = None
     else:
-        shape = list(tensor.shape)
-        shape[0], shape[dim] = shape[dim], shape[0]
-        shape[0] *= depth
-        out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
-        temp = list(torch.chunk(out, depth, dim=0))
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.all_gather(tensor_list=temp,
-                               tensor=tensor.transpose(0, dim).contiguous(),
-                               group=group,
-                               async_op=async_op)
-        out = torch.transpose(out, 0, dim)
+        work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
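The rewritten `all_gather` no longer builds a Python list of chunks: it moves the gather dimension to the front, runs one flat all-gather into a preallocated contiguous buffer, and transposes back only when `dim != 0`. A single-process shape sketch of that transformation, assuming a hypothetical `depth` of 2:

```python
import torch

depth, dim = 2, -1                       # hypothetical parallel depth; gather along the last dim
x = torch.randn(4, 3, 5)                 # local shard on one rank
x_in = x.transpose(0, dim).contiguous()  # (5, 3, 4): gather dim is now leading
out_shape = (x_in.shape[0] * depth,) + x_in.shape[1:]
tensor_out = torch.empty(out_shape)      # (10, 3, 4), filled by _all_gather_func in the real code
gathered = tensor_out.transpose(0, dim)  # (4, 3, 10): original shape scaled by depth on `dim`
```

Once the transpose is undone, concatenating the per-rank blocks along dim 0 of the buffer is the same as concatenating the original shards along `dim`.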
@@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
         out = tensor
         work = None
     else:
-        temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
-        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
+        work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
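`reduce_scatter` mirrors the same idea in the other direction: the scatter dimension is moved to the front so the flat collective can split the leading dimension evenly across ranks, and the result is transposed back afterwards. A shape-only sketch, again with a hypothetical depth of 2:

```python
import torch

depth, dim = 2, 1
x = torch.randn(4, 6)                    # full local tensor, reduced and scattered along dim 1
x_in = x.transpose(0, dim).contiguous()  # (6, 4)
out_shape = (x_in.shape[0] // depth,) + x_in.shape[1:]
tensor_out = torch.empty(out_shape)      # (3, 4), filled by _reduce_scatter_func in the real code
shard = tensor_out.transpose(0, dim)     # (4, 3): this rank's reduced shard of the original
```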
@@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =


 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
-    r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
+    r"""Modified from `torch.distributed.scatter_object_list
+    <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
         return
@@ -34,7 +34,7 @@ class _Linear3D(torch.autograd.Function):
         ctx.output_parallel_mode = output_parallel_mode

         input_ = all_gather(input_, 0, input_parallel_mode)
-        weight = all_gather(weight, -1, weight_parallel_mode)
+        weight = all_gather(weight, 0, weight_parallel_mode)
         ctx.save_for_backward(input_, weight)

         output = torch.matmul(input_, weight)
@@ -53,7 +53,7 @@ class _Linear3D(torch.autograd.Function):

         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
-        weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
+        weight_grad, op = reduce_scatter(weight_grad, 0, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

         input_op.wait()
@@ -205,7 +205,7 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
         ctx.weight_id = weight_id

         input_ = all_gather(input_, 0, input_parallel_mode)
-        weight = all_gather(weight.transpose(0, 1), -1, weight_parallel_mode)
+        weight = all_gather(weight, 0, weight_parallel_mode).transpose(0, 1)
         ctx.save_for_backward(input_, weight)

         output = torch.matmul(input_, weight)
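These three hunks switch the weight communication from the last dimension to dimension 0, which is the cheap path for the flat collectives above (no transpose is needed when `dim == 0`). The `_VocabParallelClassifier3D` change is the clearest case: gathering the untransposed weight along dim 0 and transposing once afterwards yields the same tensor as transposing the shard first and gathering along the last dim (the `_Linear3D` changes additionally rely on the new weight partitioning in the next hunk). A single-process sketch of that identity, using `torch.cat` in place of the collective and hypothetical shard shapes:

```python
import torch

depth = 2
shards = [torch.randn(3, 4) for _ in range(depth)]  # hypothetical per-rank weight shards

old = torch.cat([s.transpose(0, 1) for s in shards], dim=-1)  # transpose each shard, gather on dim -1
new = torch.cat(shards, dim=0).transpose(0, 1)                # gather on dim 0, transpose once
assert torch.equal(old, new)                                  # same (4, 6) result either way
```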
@@ -196,8 +196,8 @@ class Linear3D(ParallelLayer):
         self.output_x_weight_parallel_mode = get_parallel_mode_from_env(OUTPUT_X_WEIGHT_3D)
         self.depth = get_depth_from_env()
         self.skip_bias_add = skip_bias_add
-        self.in_features_per_partition = divide(in_features, self.depth)
-        self.out_features_per_partition = divide(out_features, self.depth**2)
+        self.in_features_per_partition = divide(in_features, self.depth**2)
+        self.out_features_per_partition = divide(out_features, self.depth)
         self.bias_features_per_partition = divide(out_features, self.depth)

         self.weight = Parameter(
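The swap of `divide` factors moves the extra factor of `depth` from the output axis to the input axis of each rank's weight partition. Assuming the weight parameter is allocated as `(in_features_per_partition, out_features_per_partition)` (the `_Linear3D` forward gathers it and multiplies without a transpose), a dim-0 gather over the weight-parallel group then restores the same operand shape as the old dim -1 gather did. A quick shape check under a hypothetical configuration:

```python
# Hypothetical sizes: depth = 2, in_features = 8, out_features = 12.
depth, in_features, out_features = 2, 8, 12

old_shard = (in_features // depth, out_features // depth**2)  # (4, 3), was gathered along dim -1
new_shard = (in_features // depth**2, out_features // depth)  # (2, 6), now gathered along dim 0

old_gathered = (old_shard[0], old_shard[1] * depth)           # (4, 6)
new_gathered = (new_shard[0] * depth, new_shard[1])           # (4, 6)
assert old_gathered == new_gathered == (in_features // depth, out_features // depth)
```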
@@ -287,7 +287,7 @@ class Linear3D(ParallelLayer):
             local_state,
             self.weight_parallel_mode,
             dims={
-                weight_key: -1,
+                weight_key: 0,
                 bias_key: 0
             },
             partition_states={
@@ -310,7 +310,7 @@ class Linear3D(ParallelLayer):
             local_state,
             self.weight_parallel_mode,
             dims={
-                weight_key: -1,
+                weight_key: 0,
                 bias_key: 0
             },
             partition_states={
@@ -4,12 +4,23 @@
 import time

 import torch

 from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.core import global_context
 from colossalai.logging import get_dist_logger
-from colossalai.nn import (Classifier3D, CrossEntropyLoss3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D,
-                           VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier3D,
-                           VocabParallelCrossEntropyLoss3D, VocabParallelEmbedding3D)
+from colossalai.nn import (
+    Classifier3D,
+    CrossEntropyLoss3D,
+    Embedding3D,
+    LayerNorm3D,
+    Linear3D,
+    PatchEmbedding3D,
+    VanillaClassifier,
+    VanillaPatchEmbedding,
+    VocabParallelClassifier3D,
+    VocabParallelCrossEntropyLoss3D,
+    VocabParallelEmbedding3D,
+)
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.utils import get_current_device, print_rank_0

@@ -40,7 +51,7 @@ def check_linear():
     torch.distributed.broadcast(weight_master, src=0)
     weight = torch.chunk(weight_master, DEPTH, dim=0)[k]
     weight = torch.chunk(weight, DEPTH, dim=-1)[j]
-    weight = torch.chunk(weight, DEPTH, dim=-1)[i]
+    weight = torch.chunk(weight, DEPTH, dim=0)[i]
     layer.weight.data.copy_(weight)
     bias_master = layer_master.bias.data
     torch.distributed.broadcast(bias_master, src=0)
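The test updates follow the same layout change: the master weight is now chunked along dim 0 twice (for `k` and `i`) and along the last dim once (for `j`), so each rank's shard keeps 1/DEPTH² of dim 0 and 1/DEPTH of the last dim, matching the new `(in_features / depth**2, out_features / depth)` partition. A shape-only sketch with a hypothetical `DEPTH` of 2 and a hypothetical master shape:

```python
import torch

DEPTH, i, j, k = 2, 0, 0, 0                            # hypothetical depth and rank coordinates
weight_master = torch.randn(8, 12)                     # hypothetical full weight
weight = torch.chunk(weight_master, DEPTH, dim=0)[k]   # (4, 12)
weight = torch.chunk(weight, DEPTH, dim=-1)[j]         # (4, 6)
weight = torch.chunk(weight, DEPTH, dim=0)[i]          # (2, 6): dim 0 split twice, last dim once
assert weight.shape == (8 // DEPTH**2, 12 // DEPTH)
```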
@@ -93,7 +104,7 @@ def check_linear():
     B_grad = layer_master.weight.grad.transpose(0, 1)
     B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
     B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
-    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
+    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
     logger.info('Rank {} linear backward (weight_grad): {}'.format(rank, check_equal(B_grad, layer.weight.grad)))

     bias_grad = layer_master.bias.grad
@@ -775,7 +786,7 @@ def check_loss():

     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]
@@ -828,7 +839,7 @@ def check_vocab_parallel_loss():

     out_shape = (BATCH_SIZE, NUM_CLASSES)
     out_master = torch.randn(out_shape, device=device)
-    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
+    target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
     torch.distributed.broadcast(out_master, src=0)
     torch.distributed.broadcast(target_master, src=0)
     out = torch.chunk(out_master, DEPTH, dim=0)[i]