diff --git a/colossalai/communication/collective.py b/colossalai/communication/collective.py index bd5a984ae..7eeac241e 100644 --- a/colossalai/communication/collective.py +++ b/colossalai/communication/collective.py @@ -8,10 +8,13 @@ from torch import Tensor from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device -def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: +def all_gather(tensor: Tensor, + dim: int, + parallel_mode: ParallelMode, + on_cpu: bool = False, + async_op: bool = False) -> Tensor: r"""Gathers all tensors from the parallel group and concatenates them in a specific dimension. @@ -23,6 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: tensor (:class:`torch.Tensor`): Tensor to be gathered. dim (int): The dimension concatenating in. parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + on_cpu (bool, optional): Whether to communicate with Gloo backend. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -37,11 +41,12 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: shape = list(tensor.shape) shape[0], shape[dim] = shape[dim], shape[0] shape[0] *= depth - out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device()) + out = torch.empty(shape, dtype=tensor.dtype, device=tensor.device) temp = list(torch.chunk(out, depth, dim=0)) + group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode) work = dist.all_gather(tensor_list=temp, tensor=tensor.transpose(0, dim).contiguous(), - group=gpc.get_group(parallel_mode), + group=group, async_op=async_op) out = torch.transpose(out, 0, dim) if async_op: @@ -54,6 +59,7 @@ def reduce_scatter(tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, + on_cpu: bool = False, async_op: bool = False) -> Tensor: r"""Reduces all tensors then scatters it in a specific dimension to all members in the parallel group. @@ -70,6 +76,7 @@ def reduce_scatter(tensor: Tensor, should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. More details about ReduceOp please refer to `ReduceOp `_. + on_cpu (bool, optional): Whether to communicate with Gloo backend. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -82,12 +89,9 @@ def reduce_scatter(tensor: Tensor, work = None else: temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim))) - out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device()) - work = dist.reduce_scatter(output=out, - input_list=temp, - op=op, - group=gpc.get_group(parallel_mode), - async_op=async_op) + out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=tensor.device) + group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode) + work = dist.reduce_scatter(output=out, input_list=temp, op=op, group=group, async_op=async_op) if async_op: return out, work else: @@ -97,6 +101,7 @@ def reduce_scatter(tensor: Tensor, def all_reduce(tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, + on_cpu: bool = False, async_op: bool = False) -> Tensor: r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. @@ -111,6 +116,7 @@ def all_reduce(tensor: Tensor, should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. More details about ReduceOp please refer to `ReduceOp `_. + on_cpu (bool, optional): Whether to communicate with Gloo backend. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -123,14 +129,15 @@ def all_reduce(tensor: Tensor, work = None else: out = tensor.contiguous() - work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op) + group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode) + work = dist.all_reduce(out, op=op, group=group, async_op=async_op) if async_op: return out, work else: return out -def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False): +def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, on_cpu: bool = False, async_op: bool = False): r"""Broadcast tensors to whole parallel group. Tensor must have the same number of elements in all processes participating in the collective. @@ -142,6 +149,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b tensor (:class:`torch.Tensor`): Tensor to be broadcast. src (int): Source rank. parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + on_cpu (bool, optional): Whether to communicate with Gloo backend. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -154,14 +162,20 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b work = None else: out = tensor.contiguous() - work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op) + group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode) + work = dist.broadcast(out, src=src, group=group, async_op=async_op) if async_op: return out, work else: return out -def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False): +def reduce(tensor: Tensor, + dst: int, + parallel_mode: ParallelMode, + op: ReduceOp = ReduceOp.SUM, + on_cpu: bool = False, + async_op: bool = False): r"""Reduce tensors across whole parallel group. Only the process with rank ``dst`` is going to receive the final result. @@ -173,6 +187,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = tensor (:class:`torch.Tensor`): Tensor to be reduced. dst (int): Destination rank. parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. + on_cpu (bool, optional): Whether to communicate with Gloo backend. async_op (bool, optional): Whether operations are asynchronous. Returns: @@ -185,8 +200,62 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = work = None else: out = tensor.contiguous() - work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op) + group = gpc.get_cpu_group(parallel_mode) if on_cpu else gpc.get_group(parallel_mode) + work = dist.reduce(out, dst=dst, op=op, group=group, async_op=async_op) if async_op: return out, work else: return out + + +def scatter_object_list(scatter_object_output_list, scatter_object_input_list, src=0, group=None): + r"""Modified from `torch.distributed.scatter_object_list ` to fix issues + """ + if dist._rank_not_in_group(group): + return + + if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): + raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") + + # set tensor device to cuda if backend is nccl + device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") + + my_rank = dist.get_rank() # use global rank + if my_rank == src: + tensor_list, tensor_sizes = zip( + *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) + tensor_list = list(map(lambda x: x.to(device), tensor_list)) + tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. + if my_rank == src: + max_tensor_size = max(tensor_sizes) + for tensor in tensor_list: + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long).to(device) + + dist.broadcast(max_tensor_size, src=src, group=group) + + # Scatter actual serialized objects + output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8).to(device) + dist.scatter( + output_tensor, + scatter_list=None if my_rank != src else tensor_list, + src=src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long).to(device) + dist.scatter( + obj_tensor_size, + scatter_list=None if my_rank != src else tensor_sizes, + src=src, + group=group, + ) + + output_tensor, obj_tensor_size = output_tensor.cpu(), obj_tensor_size.cpu() + # Deserialize back to object + scatter_object_output_list[0] = dist.distributed_c10d._tensor_to_object(output_tensor, obj_tensor_size)