improved allgather & reducescatter for 3d

Author: zbian
Date: 2023-01-03 15:26:47 +08:00
Committed by: アマデウス
Parent: c719798abe
Commit: e94c79f15b
4 changed files with 43 additions and 29 deletions


@@ -3,12 +3,17 @@
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
 from torch import Tensor
+from torch.distributed import ReduceOp
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+_all_gather_func = dist._all_gather_base \
+    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
+_reduce_scatter_func = dist._reduce_scatter_base \
+    if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
 def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
@@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
         out = tensor
         work = None
     else:
-        shape = list(tensor.shape)
-        shape[0], shape[dim] = shape[dim], shape[0]
-        shape[0] *= depth
-        out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
-        temp = list(torch.chunk(out, depth, dim=0))
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.all_gather(tensor_list=temp,
-                               tensor=tensor.transpose(0, dim).contiguous(),
-                               group=group,
-                               async_op=async_op)
-        out = torch.transpose(out, 0, dim)
+        work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
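The rewritten branch replaces the per-chunk dist.all_gather (a list of depth output chunks) with a single flat output tensor: the gather dimension is moved to dim 0 when needed, the collective fills one contiguous buffer whose first dimension is depth times larger, and the axes are swapped back afterwards. A small sketch of that shape bookkeeping, runnable on CPU without a process group (gathered_shape and depth are illustrative names, not part of the patch):

import torch

def gathered_shape(tensor: torch.Tensor, dim: int, depth: int) -> tuple:
    # Mirror the new all_gather path: move `dim` to the front, grow dim 0
    # by the world size (`depth`), then swap the axes back.
    tensor_in = tensor if dim == 0 else tensor.transpose(0, dim)
    out_shape = (tensor_in.shape[0] * depth,) + tuple(tensor_in.shape[1:])
    tensor_out = torch.empty(out_shape)
    out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
    return tuple(out.shape)

# Gathering a (4, 6, 8) tensor along dim=1 across 2 ranks yields (4, 12, 8).
assert gathered_shape(torch.empty(4, 6, 8), dim=1, depth=2) == (4, 12, 8)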
@@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
         out = tensor
         work = None
     else:
-        temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
-        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
+        work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
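The reduce_scatter branch gets the mirror-image treatment: the scatter dimension is moved to dim 0, the flat collective writes a buffer whose first dimension is depth times smaller, and the axes are swapped back; for the flat collective to line up, the scattered dimension has to be divisible by the world size. The same kind of standalone shape check as above (scattered_shape is an illustrative name only):

import torch

def scattered_shape(tensor: torch.Tensor, dim: int, depth: int) -> tuple:
    # Mirror the new reduce_scatter path: move `dim` to the front, shrink
    # dim 0 by the world size (`depth`), then swap the axes back.
    tensor_in = tensor if dim == 0 else tensor.transpose(0, dim)
    assert tensor_in.shape[0] % depth == 0, "scatter dim must divide evenly across ranks"
    out_shape = (tensor_in.shape[0] // depth,) + tuple(tensor_in.shape[1:])
    tensor_out = torch.empty(out_shape)
    out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
    return tuple(out.shape)

# Reduce-scattering a (5, 12, 8) tensor along dim=1 across 3 ranks yields (5, 4, 8).
assert scattered_shape(torch.empty(5, 12, 8), dim=1, depth=3) == (5, 4, 8)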
@@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
-    r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
+    r"""Modified from `torch.distributed.scatter_object_list
+    <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
         return