mirror of https://github.com/hpcaitech/ColossalAI.git

improved allgather & reducescatter for 3d
@@ -3,12 +3,17 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ReduceOp
 from torch import Tensor
+from torch.distributed import ReduceOp
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
+_all_gather_func = dist._all_gather_base \
+    if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
+_reduce_scatter_func = dist._reduce_scatter_base \
+    if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
+
 
 def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
     r"""Gathers all tensors from the parallel group and concatenates them in a
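For illustration only (not part of the commit): the new module-level helpers use the tensor-based collectives when this PyTorch build exposes them and fall back to the older underscore-prefixed variants otherwise. A minimal standalone sketch of the same idea, assuming a PyTorch build whose gloo backend implements the tensor-based all-gather; runnable with e.g. `torchrun --nproc_per_node=2 sketch.py` (the file name is hypothetical):

import torch
import torch.distributed as dist


def main():
    # torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for us.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor_in = torch.full((2, 3), float(rank))    # this rank's shard
    out_shape = (tensor_in.shape[0] * world_size,) + tensor_in.shape[1:]
    tensor_out = torch.empty(out_shape, dtype=tensor_in.dtype)

    # Tensor-based collective: one contiguous, pre-allocated output buffer
    # instead of a Python list of per-rank chunks.
    if hasattr(dist, "all_gather_into_tensor"):
        dist.all_gather_into_tensor(tensor_out, tensor_in)
    else:
        dist._all_gather_base(tensor_out, tensor_in)

    if rank == 0:
        print(tensor_out.shape)                    # torch.Size([2 * world_size, 3])
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Gathering into a single flat buffer is what lets the rewritten functions below drop the chunk lists and the extra copies.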
@@ -33,17 +38,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
         out = tensor
         work = None
     else:
-        shape = list(tensor.shape)
-        shape[0], shape[dim] = shape[dim], shape[0]
-        shape[0] *= depth
-        out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device)
-        temp = list(torch.chunk(out, depth, dim=0))
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.all_gather(tensor_list=temp,
-                               tensor=tensor.transpose(0, dim).contiguous(),
-                               group=group,
-                               async_op=async_op)
-        out = torch.transpose(out, 0, dim)
+        work = _all_gather_func(tensor_out, tensor_in, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
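The rewritten all_gather path moves the gather dimension to dim 0, gathers into one flat buffer with `_all_gather_func`, and exposes the result through a transpose rather than reassembling a list of chunks. A shape-only sketch of that bookkeeping (illustrative, no process group needed; `depth` stands in for the parallel group size):

import torch

depth, dim = 4, 1                      # pretend group size and gather dim
tensor = torch.randn(3, 5, 7)          # one rank's shard

# Move the gather dim to the front, as the new code does.
tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
out_shape = (tensor_in.shape[0] * depth,) + tensor_in.shape[1:]
tensor_out = torch.empty(out_shape)    # the buffer _all_gather_func would fill

# The caller sees the gathered tensor through a transposed view, no extra copy.
out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
print(out.shape)                       # torch.Size([3, 20, 7]): grown along `dim`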
@@ -81,10 +81,12 @@ def reduce_scatter(tensor: Tensor,
         out = tensor
         work = None
     else:
-        temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
-        out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device)
+        tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
+        out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
+        tensor_out = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
         group = gpc.get_cpu_group(parallel_mode) if tensor.device.type == "cpu" else gpc.get_group(parallel_mode)
-        work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op)
+        work = _reduce_scatter_func(tensor_out, tensor_in, op=op, group=group, async_op=async_op)
+        out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
     if async_op:
         return out, work
     else:
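reduce_scatter gets the mirror-image treatment: the scatter dimension is moved to dim 0, it must divide evenly by the group size, and each rank keeps one reduced 1/depth slice produced by `_reduce_scatter_func`. A shape-only sketch under the same assumptions as above:

import torch

depth, dim = 4, 1                      # pretend group size and scatter dim
tensor = torch.randn(3, 20, 7)         # full tensor held by one rank

tensor_in = tensor.contiguous() if dim == 0 else tensor.transpose(0, dim).contiguous()
assert tensor_in.shape[0] % depth == 0, "scatter dim must divide evenly across the group"
out_shape = (tensor_in.shape[0] // depth,) + tensor_in.shape[1:]
tensor_out = torch.empty(out_shape)    # the buffer _reduce_scatter_func would fill

out = tensor_out if dim == 0 else tensor_out.transpose(0, dim)
print(out.shape)                       # torch.Size([3, 5, 7]): one reduced shard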
@@ -193,7 +195,8 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
 
 
 def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None) -> None:
-    r"""Modified from `torch.distributed.scatter_object_list <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
+    r"""Modified from `torch.distributed.scatter_object_list
+    <https://pytorch.org/docs/stable/_modules/torch/distributed/distributed_c10d.html#scatter_object_list>` to fix issues
     """
     if dist.distributed_c10d._rank_not_in_group(group):
         return